diff --git a/.gitignore b/.gitignore index 818a264d..59f463de 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,8 @@ LocalAI local-ai # prevent above rules from omitting the helm chart !charts/* +# prevent above rules from omitting the api/localai folder +!api/localai # Ignore models models/* diff --git a/Makefile b/Makefile index 752370e5..6fdf0685 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ GOVET=$(GOCMD) vet BINARY_NAME=local-ai # llama.cpp versions -GOLLAMA_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7 +GOLLAMA_VERSION?=f03869d188b72c8a617bea3a36cf8eb43f73445c # gpt4all version GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all diff --git a/api/api.go b/api/api.go index 66a1db09..de18e182 100644 --- a/api/api.go +++ b/api/api.go @@ -2,6 +2,7 @@ package api import ( "errors" + "fmt" "strings" config "github.com/go-skynet/LocalAI/api/config" @@ -19,7 +20,7 @@ import ( "github.com/rs/zerolog/log" ) -func App(opts ...options.AppOption) (*fiber.App, error) { +func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) { options := options.NewOptions(opts...) 
zerolog.SetGlobalLevel(zerolog.InfoLevel) @@ -27,6 +28,65 @@ func App(opts ...options.AppOption) (*fiber.App, error) { zerolog.SetGlobalLevel(zerolog.DebugLevel) } + log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) + log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) + + cl := config.NewConfigLoader() + if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil { + log.Error().Msgf("error loading config files: %s", err.Error()) + } + + if options.ConfigFile != "" { + if err := cl.LoadConfigFile(options.ConfigFile); err != nil { + log.Error().Msgf("error loading config file: %s", err.Error()) + } + } + + if options.Debug { + for _, v := range cl.ListConfigs() { + cfg, _ := cl.GetConfig(v) + log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) + } + } + + if options.AssetsDestination != "" { + // Extract files from the embedded FS + err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) + log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) + if err != nil { + log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) + } + } + + if options.PreloadJSONModels != "" { + if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { + return nil, nil, err + } + } + + if options.PreloadModelsFromPath != "" { + if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { + return nil, nil, err + } + } + + // turn off any process that was started by GRPC if the context is canceled + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + options.Loader.StopAllGRPC() + }() + + return options, cl, nil +} + +func App(opts ...options.AppOption) (*fiber.App, error) { + + options, cl, 
err := Startup(opts...) + if err != nil { + return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error()) + } + // Return errors as JSON responses app := fiber.New(fiber.Config{ BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB @@ -57,36 +117,6 @@ func App(opts ...options.AppOption) (*fiber.App, error) { })) } - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) - log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) - - cm := config.NewConfigLoader() - if err := cm.LoadConfigs(options.Loader.ModelPath); err != nil { - log.Error().Msgf("error loading config files: %s", err.Error()) - } - - if options.ConfigFile != "" { - if err := cm.LoadConfigFile(options.ConfigFile); err != nil { - log.Error().Msgf("error loading config file: %s", err.Error()) - } - } - - if options.Debug { - for _, v := range cm.ListConfigs() { - cfg, _ := cm.GetConfig(v) - log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) - } - } - - if options.AssetsDestination != "" { - // Extract files from the embedded FS - err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) - log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) - if err != nil { - log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) - } - } - // Default middleware config app.Use(recover.New()) @@ -116,18 +146,6 @@ func App(opts ...options.AppOption) (*fiber.App, error) { return c.Next() } - if options.PreloadJSONModels != "" { - if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil { - return nil, err - } - } - - if options.PreloadModelsFromPath != "" { - if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cm, options.Galleries); err != nil { - return 
nil, err - } - } - if options.CORS { var c func(ctx *fiber.Ctx) error if options.CORSAllowOrigins == "" { @@ -141,7 +159,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) { // LocalAI API endpoints galleryService := localai.NewGalleryService(options.Loader.ModelPath) - galleryService.Start(options.Context, cm) + galleryService.Start(options.Context, cl) app.Get("/version", auth, func(c *fiber.Ctx) error { return c.JSON(struct { @@ -149,36 +167,36 @@ func App(opts ...options.AppOption) (*fiber.App, error) { }{Version: internal.PrintableVersion()}) }) - app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries)) + app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cl, galleryService.C, options.Galleries)) app.Get("/models/available", auth, localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath)) app.Get("/models/jobs/:uuid", auth, localai.GetOpStatusEndpoint(galleryService)) // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cm, options)) - app.Post("/chat/completions", auth, openai.ChatEndpoint(cm, options)) + app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options)) + app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options)) // edit - app.Post("/v1/edits", auth, openai.EditEndpoint(cm, options)) - app.Post("/edits", auth, openai.EditEndpoint(cm, options)) + app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) + app.Post("/edits", auth, openai.EditEndpoint(cl, options)) // completion - app.Post("/v1/completions", auth, openai.CompletionEndpoint(cm, options)) - app.Post("/completions", auth, openai.CompletionEndpoint(cm, options)) - app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cm, options)) + app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options)) + app.Post("/completions", auth, 
openai.CompletionEndpoint(cl, options)) + app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options)) // embeddings - app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cm, options)) - app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cm, options)) - app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cm, options)) + app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) + app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) + app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) // audio - app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cm, options)) - app.Post("/tts", auth, localai.TTSEndpoint(cm, options)) + app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options)) + app.Post("/tts", auth, localai.TTSEndpoint(cl, options)) // images - app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cm, options)) + app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options)) if options.ImageDir != "" { app.Static("/generated-images", options.ImageDir) @@ -196,16 +214,13 @@ func App(opts ...options.AppOption) (*fiber.App, error) { app.Get("/healthz", ok) app.Get("/readyz", ok) - // models - app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cm)) - app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cm)) + // Experimental Backend Statistics Module + backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now + app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) - // turn off any process that was started by GRPC if the context is canceled - go func() { - <-options.Context.Done() - log.Debug().Msgf("Context canceled, shutting down") - options.Loader.StopGRPC() - }() + // models + app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) + app.Get("/models", auth, 
openai.ListModelsEndpoint(options.Loader, cl)) return app, nil } diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go index 554cb11e..aa1e393f 100644 --- a/api/backend/embeddings.go +++ b/api/backend/embeddings.go @@ -2,7 +2,6 @@ package backend import ( "fmt" - "sync" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" @@ -88,18 +87,6 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. } return func() ([]float32, error) { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[modelFile] - if !ok { - m := &sync.Mutex{} - mutexes[modelFile] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - embeds, err := fn() if err != nil { return embeds, err diff --git a/api/backend/image.go b/api/backend/image.go index 81303926..9c9ad6c0 100644 --- a/api/backend/image.go +++ b/api/backend/image.go @@ -1,8 +1,6 @@ package backend import ( - "sync" - config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/grpc/proto" @@ -67,19 +65,5 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat return err } - return func() error { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[c.Backend] - if !ok { - m := &sync.Mutex{} - mutexes[c.Backend] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - return fn() - }, nil + return fn, nil } diff --git a/api/backend/llm.go b/api/backend/llm.go index 80067e7c..c30e0f81 100644 --- a/api/backend/llm.go +++ b/api/backend/llm.go @@ -15,7 +15,17 @@ import ( "github.com/go-skynet/LocalAI/pkg/utils" ) -func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { +type LLMResponse 
struct { + Response string // should this be []byte? + Usage TokenUsage +} + +type TokenUsage struct { + Prompt int + Completion int +} + +func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { modelFile := c.Model grpcOpts := gRPCModelOpts(c) @@ -70,40 +80,56 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c } // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported - fn := func() (string, error) { + fn := func() (LLMResponse, error) { opts := gRPCPredictOpts(c, loader.ModelPath) opts.Prompt = s + + tokenUsage := TokenUsage{} + + // check the per-model feature flag for usage, since tokenCallback may have a cost, but default to on. + if !c.FeatureFlag["usage"] { + userTokenCallback := tokenCallback + if userTokenCallback == nil { + userTokenCallback = func(token string, usage TokenUsage) bool { + return true + } + } + + promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) + if pErr == nil && promptInfo.Length > 0 { + tokenUsage.Prompt = int(promptInfo.Length) + } + + tokenCallback = func(token string, usage TokenUsage) bool { + tokenUsage.Completion++ + return userTokenCallback(token, tokenUsage) + } + } + if tokenCallback != nil { ss := "" err := inferenceModel.PredictStream(ctx, opts, func(s []byte) { - tokenCallback(string(s)) + tokenCallback(string(s), tokenUsage) ss += string(s) }) - return ss, err + return LLMResponse{ + Response: ss, + Usage: tokenUsage, + }, err } else { + // TODO: Is the chicken bit the only way to get here? is that acceptable? 
reply, err := inferenceModel.Predict(ctx, opts) if err != nil { - return "", err + return LLMResponse{}, err } - return string(reply.Message), err + return LLMResponse{ + Response: string(reply.Message), + Usage: tokenUsage, + }, err } } - return func() (string, error) { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[modelFile] - if !ok { - m := &sync.Mutex{} - mutexes[modelFile] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - return fn() - }, nil + return fn, nil } var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) diff --git a/api/backend/lock.go b/api/backend/lock.go deleted file mode 100644 index 6b4f577c..00000000 --- a/api/backend/lock.go +++ /dev/null @@ -1,22 +0,0 @@ -package backend - -import "sync" - -// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 -var mutexMap sync.Mutex -var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) - -func Lock(s string) *sync.Mutex { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[s] - if !ok { - m := &sync.Mutex{} - mutexes[s] = m - l = m - } - mutexMap.Unlock() - l.Lock() - - return l -} diff --git a/api/config/config.go b/api/config/config.go index 744e3723..24a6658e 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -29,6 +29,7 @@ type Config struct { FunctionsConfig Functions `yaml:"function"` + FeatureFlag map[string]bool `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early. // LLM configs (GPT4ALL, Llama.cpp, ...) 
LLMConfig `yaml:",inline"` diff --git a/api/localai/backend_monitor.go b/api/localai/backend_monitor.go new file mode 100644 index 00000000..f723cddf --- /dev/null +++ b/api/localai/backend_monitor.go @@ -0,0 +1,142 @@ +package localai + +import ( + "context" + "fmt" + "strings" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + + gopsutil "github.com/shirou/gopsutil/v3/process" +) + +type BackendMonitorRequest struct { + Model string `json:"model" yaml:"model"` +} + +type BackendMonitorResponse struct { + MemoryInfo *gopsutil.MemoryInfoStat + MemoryPercent float32 + CPUPercent float64 +} + +type BackendMonitor struct { + configLoader *config.ConfigLoader + options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. +} + +func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor { + return BackendMonitor{ + configLoader: configLoader, + options: options, + } +} + +func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) { + config, exists := bm.configLoader.GetConfig(model) + var backend string + if exists { + backend = config.Model + } else { + // Last ditch effort: use it raw, see if a backend happens to match. + backend = model + } + + if !strings.HasSuffix(backend, ".bin") { + backend = fmt.Sprintf("%s.bin", backend) + } + + pid, err := bm.options.Loader.GetGRPCPID(backend) + + if err != nil { + log.Error().Msgf("model %s : failed to find pid %+v", model, err) + return nil, err + } + + // Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID. 
+ backendProcess, err := gopsutil.NewProcess(int32(pid)) + + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err) + return nil, err + } + + memInfo, err := backendProcess.MemoryInfo() + + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err) + return nil, err + } + + memPercent, err := backendProcess.MemoryPercent() + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err) + return nil, err + } + + cpuPercent, err := backendProcess.CPUPercent() + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err) + return nil, err + } + + return &BackendMonitorResponse{ + MemoryInfo: memInfo, + MemoryPercent: memPercent, + CPUPercent: cpuPercent, + }, nil +} + +func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(BackendMonitorRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + config, exists := bm.configLoader.GetConfig(input.Model) + var backendId string + if exists { + backendId = config.Model + } else { + // Last ditch effort: use it raw, see if a backend happens to match. 
+ backendId = input.Model + } + + if !strings.HasSuffix(backendId, ".bin") { + backendId = fmt.Sprintf("%s.bin", backendId) + } + + client := bm.options.Loader.CheckIsLoaded(backendId) + + if client == nil { + return fmt.Errorf("backend %s is not currently loaded", input.Model) + } + + status, rpcErr := client.Status(context.TODO()) + if rpcErr != nil { + log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", input.Model, rpcErr.Error()) + val, slbErr := bm.SampleLocalBackendProcess(backendId) + if slbErr != nil { + return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", input.Model, rpcErr.Error(), slbErr.Error()) + } + return c.JSON(proto.StatusResponse{ + State: proto.StatusResponse_ERROR, + Memory: &proto.MemoryUsageData{ + Total: val.MemoryInfo.VMS, + Breakdown: map[string]uint64{ + "gopsutil-RSS": val.MemoryInfo.RSS, + }, + }, + }) + } + + return c.JSON(status) + } +} diff --git a/api/openai/chat.go b/api/openai/chat.go index a38e2bf3..6393e5d8 100644 --- a/api/openai/chat.go +++ b/api/openai/chat.go @@ -29,11 +29,16 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } responses <- initialMessage - ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { + ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string, usage backend.TokenUsage) bool { resp := OpenAIResponse{ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, Object: "chat.completion.chunk", + Usage: OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, } responses <- resp @@ -237,11 +242,13 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + usage := &OpenAIUsage{} + for ev := range responses { + usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it var buf bytes.Buffer enc := json.NewEncoder(&buf) enc.Encode(ev) - log.Debug().Msgf("Sending chunk: %s", buf.String()) _, err := fmt.Fprintf(w, "data: %v\n", buf.String()) if err != nil { @@ -261,6 +268,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) Delta: &Message{Content: &emptyMessage}, }}, Object: "chat.completion.chunk", + Usage: *usage, } respData, _ := json.Marshal(resp) @@ -271,7 +279,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) return nil } - result, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) { + result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) { if processFunctions { // As we have to change the result before processing, we can't stream the answer (yet?) 
ss := map[string]interface{}{} @@ -327,8 +335,8 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) return } - prediction = backend.Finetune(*config, predInput, prediction) - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) + fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response) + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &fineTunedResponse}}) } else { // otherwise reply with the function call *c = append(*c, Choice{ @@ -349,6 +357,11 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: result, Object: "chat.completion", + Usage: OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, } respData, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", respData) diff --git a/api/openai/completion.go b/api/openai/completion.go index 19d24be1..20d15d4e 100644 --- a/api/openai/completion.go +++ b/api/openai/completion.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" + "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" model "github.com/go-skynet/LocalAI/pkg/model" @@ -18,7 +19,7 @@ import ( // https://platform.openai.com/docs/api-reference/completions func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { - ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { + ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string, usage backend.TokenUsage) bool { resp := OpenAIResponse{ Model: req.Model, // we have to return what 
the user sent here, due to OpenAI spec. Choices: []Choice{ @@ -28,6 +29,11 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe }, }, Object: "text_completion", + Usage: OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, } log.Debug().Msgf("Sending goroutine: %s", s) @@ -120,6 +126,9 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe } var result []Choice + + totalTokenUsage := backend.TokenUsage{} + for k, i := range config.PromptStrings { // A model can have a "file.bin.tmpl" file associated with a prompt template prefix templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ @@ -131,13 +140,16 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe log.Debug().Msgf("Template found, input modified to: %s", i) } - r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) { + r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) { *c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k}) }, nil) if err != nil { return err } + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + result = append(result, r...) } @@ -145,6 +157,11 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: result, Object: "text_completion", + Usage: OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, } jsonResult, _ := json.Marshal(resp) diff --git a/api/openai/edit.go b/api/openai/edit.go index ef37131a..6b4664df 100644 --- a/api/openai/edit.go +++ b/api/openai/edit.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" + "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" model "github.com/go-skynet/LocalAI/pkg/model" @@ -32,6 +33,8 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } var result []Choice + totalTokenUsage := backend.TokenUsage{} + for _, i := range config.InputStrings { // A model can have a "file.bin.tmpl" file associated with a prompt template prefix templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ @@ -44,13 +47,16 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) log.Debug().Msgf("Template found, input modified to: %s", i) } - r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) { + r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) { *c = append(*c, Choice{Text: s}) }, nil) if err != nil { return err } + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + result = append(result, r...) } @@ -58,6 +64,11 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: result, Object: "edit", + Usage: OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, } jsonResult, _ := json.Marshal(resp) diff --git a/api/openai/inference.go b/api/openai/inference.go index 68d7ae85..2f34d82e 100644 --- a/api/openai/inference.go +++ b/api/openai/inference.go @@ -7,8 +7,8 @@ import ( model "github.com/go-skynet/LocalAI/pkg/model" ) -func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { - n := req.N +func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string, backend.TokenUsage) bool) ([]Choice, backend.TokenUsage, error) { + n := req.N // number of completions to return result := []Choice{} if n == 0 { @@ -18,20 +18,25 @@ func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, // get the model function to call for the result predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback) if err != nil { - return result, err + return result, backend.TokenUsage{}, err } + tokenUsage := backend.TokenUsage{} + for i := 0; i < n; i++ { prediction, err := predFunc() if err != nil { - return result, err + return result, backend.TokenUsage{}, err } - prediction = backend.Finetune(*config, predInput, prediction) - cb(prediction, &result) + tokenUsage.Prompt += prediction.Usage.Prompt + tokenUsage.Completion += prediction.Usage.Completion + + finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) + cb(finetunedResponse, &result) //result = append(result, Choice{Text: prediction}) } - return result, err + return result, tokenUsage, err } diff --git a/go.mod b/go.mod index 
844df7f2..84339e9b 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,17 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) +require ( + github.com/go-ole/go-ole v1.2.6 // indirect + github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/shirou/gopsutil/v3 v3.23.6 + github.com/shoenig/go-m1cpu v0.1.6 // indirect + github.com/tklauser/go-sysconf v0.3.11 // indirect + github.com/tklauser/numcpus v0.6.0 // indirect + github.com/yusufpapurcu/wmi v1.2.3 // indirect +) + require ( github.com/dlclark/regexp2 v1.8.1 // indirect github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect @@ -50,7 +61,6 @@ require ( github.com/pkoukk/tiktoken-go v0.1.2 // indirect github.com/ulikunitz/xz v0.5.9 // indirect github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect - google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/fsnotify.v1 v1.4.7 // indirect diff --git a/go.sum b/go.sum index 36e3cf93..b085ba67 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674 h1:G70Yf/QOCEL1v24idWnGd6rJsbqiGkJAJnMaWaolzEg= -github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df h1:qVcBEZlvp5A1gGWNJj02xyDtbsUI2hohlQMSB1fgER4= github.com/donomii/go-rwkv.cpp 
v0.0.0-20230715075832-c898cd0f62df/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L7HYpRu/0lE3e0BaElwnNO1qkNQxBY= @@ -33,24 +31,14 @@ github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa h1:gxr68r/6EWroay4iI81jxqGCDbKotY4+CiwdUkBz2NQ= github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA= github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 h1:yXvc7QfGtoZ51tUW/YVjoTwAfh8HG88XU7UOrbNlz5Y= github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1/go.mod h1:fYjkCDRzC+oRLHSjQoajmYK6AmeJnmEanV27CClAcDc= github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e h1:4reMY29i1eOZaRaSTMPNyXI7X8RMNxCTfDDBXYzrbr0= github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= -github.com/go-skynet/go-llama.cpp v0.0.0-20230709163512-6c97625cca76 h1:NRdxo2MKi8qhWZXxu6CIZOkdH+LBERFz1kk22U1FD3k= -github.com/go-skynet/go-llama.cpp v0.0.0-20230709163512-6c97625cca76/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= -github.com/go-skynet/go-llama.cpp v0.0.0-20230724222459-562d2b5a7119 h1:FeUSk5yMHT7J7jeCQKAOs4x5LRNSYH0SR6djM/i1jcc= -github.com/go-skynet/go-llama.cpp v0.0.0-20230724222459-562d2b5a7119/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA= -github.com/go-skynet/go-llama.cpp 
v0.0.0-20230727163958-6ba16de8e965 h1:2MO/rABKpkXnnKQ3Ar90aqhnlMEejE9gnKG6bafv+ow= -github.com/go-skynet/go-llama.cpp v0.0.0-20230727163958-6ba16de8e965/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA= -github.com/go-skynet/go-llama.cpp v0.0.0-20230729200103-8c51308e42d7 h1:1uBwholTaJ8Lva8ySJjT4jNaCDAh+MJXtsbZBbQq9lA= -github.com/go-skynet/go-llama.cpp v0.0.0-20230729200103-8c51308e42d7/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA= -github.com/go-skynet/go-llama.cpp v0.0.0-20230802220037-50cee7712066 h1:v4Js+yEdgY9IV7n35M+5MELLxlOMp3qC5whZm5YTLjI= -github.com/go-skynet/go-llama.cpp v0.0.0-20230802220037-50cee7712066/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA= -github.com/go-skynet/go-llama.cpp v0.0.0-20230814195654-18f25c21abf9 h1:62wpzDHwjZGfIfimvve3bNrS6/gOLkSfwsCjcSD6g8U= -github.com/go-skynet/go-llama.cpp v0.0.0-20230814195654-18f25c21abf9/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA= github.com/go-skynet/go-llama.cpp v0.0.0-20230815201253-f03869d188b7 h1:d/FXe1a55gCLf124uRYYtlYg6KvI7OI33xaFejQUAws= github.com/go-skynet/go-llama.cpp v0.0.0-20230815201253-f03869d188b7/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= @@ -76,6 +64,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 
@@ -107,6 +96,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= @@ -130,22 +121,6 @@ github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d h1:/lAg9v github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d/go.mod h1:HGGAOJhipApckwNV8ZTliRJqxctUv3xRY+zbQEwuytc= github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks= github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230714185456-cfd70b69fcf5 h1:bmQnxyKiqCu8i2y/N/Sf0coWoG2/Ed12YGQeb7lTnjo= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230714185456-cfd70b69fcf5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230725212419-9100b2ef6fb9 h1:/oRwZhulKTU8LpPD2fXi2o2kdlTutQjYWDVMkrv14po= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230725212419-9100b2ef6fb9/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230727161923-39acbc837816 
h1:hRi7hpDUuaO0dB4NZ8eyaeD2fRar6CPyNAARsO5DhzA= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230727161923-39acbc837816/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230731161838-cbdcde8b7586 h1:WVEMSZMyHFe68PN204c3Fdk5g2lZouPvbU9/2zkPpWc= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230731161838-cbdcde8b7586/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230802145814-c449b71b56de h1:E5EGczxEAcbaO8yqj074MQxU609QbtB6in3qTOW1EFo= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230802145814-c449b71b56de/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230807175413-0f2bb506a8ee h1:Y/j+GNytyncmDnAEuDZwzkYC9nzUPvXJPF+nntQG0VU= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230807175413-0f2bb506a8ee/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230811181453-4d855afe973a h1:bX26Zfwh72ug2aZTEwFISTMEJ56Wa/4KqboidD+g92A= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230811181453-4d855afe973a/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230814164545-4e55940edf11 h1:72DoTIAcKXEv5Q5MSaHFCpVAQHqwU84wUsxy/UcdKTc= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230814164545-4e55940edf11/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230815171941-a63093554fb5 h1:b4EeYDaGxOLNlNm5LOVEmrUhaw1v6xq/V79ZwWVlY6I= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230815171941-a63093554fb5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ= @@ -162,8 +137,6 @@ github.com/onsi/ginkgo/v2 
v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7 github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= -github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc= -github.com/onsi/gomega v1.27.8/go.mod h1:2J8vzI/s+2shY9XHRApDkdgPo1TKT7P2u6fXeJKFnNQ= github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= github.com/otiai10/mint v1.6.1 h1:kgbTJmOpp/0ce7hk3H8jiSuR0MXmpwWRfqUdKww17qg= @@ -178,39 +151,37 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= -github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= 
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-openai v1.14.0 h1:D1yAB+DHElgbJFdYyjxfTWMFzhddn+PwZmkQ039L7mQ= -github.com/sashabaranov/go-openai v1.14.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= -github.com/sashabaranov/go-openai v1.14.1 h1:jqfkdj8XHnBF84oi2aNtT8Ktp3EJ0MfuVjvcMkfI0LA= -github.com/sashabaranov/go-openai v1.14.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.14.2 h1:5DPTtR9JBjKPJS008/A409I5ntFhUPPGCmaAihcPRyo= github.com/sashabaranov/go-openai v1.14.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/shirou/gopsutil/v3 v3.23.6 h1:5y46WPI9QBKBbK7EEccUPNXpJpNrvPuTD0O2zHEHT08= +github.com/shirou/gopsutil/v3 v3.23.6/go.mod h1:j7QX50DrXYggrpN30W0Mo+I4/8U2UUIQrnrhqUeWrAU= +github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= +github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= +github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= +github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/tmc/langchaingo v0.0.0-20230713201705-dcf7ecdc8ac8 
h1:wdJigYmmIRCuXhCkADDr53Oa1fp/WlxCPoVXR2r7GrU= -github.com/tmc/langchaingo v0.0.0-20230713201705-dcf7ecdc8ac8/go.mod h1:mTzgQfAGwmBz2hhQELZfu2bwsbHwyKHA6IHOa+9LDFg= -github.com/tmc/langchaingo v0.0.0-20230726025230-7d5f9fd5e90a h1:I/2JSuYXkWaVVLSZmrPfrgbvvvPR0IaulZcB0Iu8oVI= -github.com/tmc/langchaingo v0.0.0-20230726025230-7d5f9fd5e90a/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E= -github.com/tmc/langchaingo v0.0.0-20230729232647-7df4fe5fb8fe h1:+XVrCjh3rPibfISkUFG2Ck5NLKODQ9cFdmraFye1bGA= -github.com/tmc/langchaingo v0.0.0-20230729232647-7df4fe5fb8fe/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E= -github.com/tmc/langchaingo v0.0.0-20230731024823-8f101609f600 h1:SABuIthjhIXEsxnokuA16CZOxxdW9XohIHQqd/go8Nc= -github.com/tmc/langchaingo v0.0.0-20230731024823-8f101609f600/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E= -github.com/tmc/langchaingo v0.0.0-20230802030916-271e9bd7e7c5 h1:js7vYDJGzUGVSt0YlIusUc5BXYVECu3LUI/asby5Ggo= -github.com/tmc/langchaingo v0.0.0-20230802030916-271e9bd7e7c5/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E= -github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537 h1:vkeNjlW+0Xiw2XizMHoQuLG8pg6AN1hU8zJuMV9GQBc= -github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM= +github.com/tklauser/go-sysconf v0.3.11/go.mod h1:GqXfhXY3kiPa0nAXPDIQIWzJbMCB7AmcWpGR8lSZfqI= +github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms= +github.com/tklauser/numcpus v0.6.0/go.mod 
h1:FEZLMke0lhOUG6w2JadTzp0a+Nl8PF/GFkQ5UVIcaL4= github.com/tmc/langchaingo v0.0.0-20230815194031-eb0cbd31327d h1:RBu2wOoyzxNxYTitUKVNDtU1H6T4Tu5skOwvZabnPFc= github.com/tmc/langchaingo v0.0.0-20230815194031-eb0cbd31327d/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E= github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= @@ -229,6 +200,8 @@ github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMx github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= +github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -240,8 +213,6 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= golang.org/x/net v0.12.0/go.mod 
h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -251,26 +222,28 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys 
v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -282,14 +255,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A= -google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 h1:9NWlQfY2ePejTmfwUH1OWwmznFa+0kKcHGPDvcPza9M= -google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54/go.mod 
h1:zqTuNwFlFRsw5zIts5VnzLQxSRqh+CGOTVMlYbY0Eyk= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM= google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= -google.golang.org/grpc v1.56.2 h1:fVRFRnXvU+x6C4IlHZewvJOVHoOv1TUuQyoRsYnB4bI= -google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw= google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/main.go b/main.go index 4f4b824c..8f5e6445 100644 --- a/main.go +++ b/main.go @@ -135,6 +135,12 @@ func main() { Usage: "List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys.", EnvVars: []string{"API_KEY"}, }, + &cli.BoolFlag{ + Name: "preload-backend-only", + Usage: "If set, the api is NOT launched, and only the preloaded models / backends are started. This is intended for multi-node setups.", + EnvVars: []string{"PRELOAD_BACKEND_ONLY"}, + Value: false, + }, }, Description: ` LocalAI is a drop-in replacement OpenAI API which runs inference locally. @@ -187,6 +193,11 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit opts = append(opts, options.EnableGalleriesAutoload) } + if ctx.Bool("preload-backend-only") { + _, _, err := api.Startup(opts...) + return err + } + app, err := api.App(opts...) 
if err != nil { return err diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index a6d89f2b..ffce63c7 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -4,17 +4,39 @@ package base // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( "fmt" + "os" + "sync" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" + gopsutil "github.com/shirou/gopsutil/v3/process" ) type Base struct { + backendBusy sync.Mutex + State pb.StatusResponse_State +} + +func (llm *Base) Busy() bool { + r := llm.backendBusy.TryLock() + if r { + llm.backendBusy.Unlock() + } + return r +} + +func (llm *Base) Lock() { + llm.backendBusy.Lock() + llm.State = pb.StatusResponse_BUSY +} + +func (llm *Base) Unlock() { + llm.State = pb.StatusResponse_READY + llm.backendBusy.Unlock() } func (llm *Base) Load(opts *pb.ModelOptions) error { return fmt.Errorf("unimplemented") - } func (llm *Base) Predict(opts *pb.PredictOptions) (string, error) { @@ -40,3 +62,32 @@ func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) { func (llm *Base) TTS(*pb.TTSRequest) error { return fmt.Errorf("unimplemented") } + +func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) { + return pb.TokenizationResponse{}, fmt.Errorf("unimplemented") +} + +// backends may wish to call this to capture the gopsutil info, then enhance with additional memory usage details? +func (llm *Base) Status() (pb.StatusResponse, error) { + + mud := pb.MemoryUsageData{ + Breakdown: make(map[string]uint64), + } + + pid := int32(os.Getpid()) + + backendProcess, err := gopsutil.NewProcess(pid) + + if err == nil { + memInfo, err := backendProcess.MemoryInfo() + if err == nil { + mud.Total = memInfo.VMS // TEST, but rss seems reasonable first guess. Does include swap, but we might care about that. 
+ mud.Breakdown["gopsutil-RSS"] = memInfo.RSS + } + } + + return pb.StatusResponse{ + State: llm.State, + Memory: &mud, + }, nil +} diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index fe845a64..cdc34ad5 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -158,3 +158,29 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques tresult.Text = res.Text return tresult, err } + +func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + + res, err := client.TokenizeString(ctx, in, opts...) + + if err != nil { + return nil, err + } + return res, nil +} + +func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.Status(ctx, &pb.HealthMessage{}) +} diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index b5713bed..6c46f764 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -6,6 +6,7 @@ import ( ) type LLM interface { + Busy() bool Predict(*pb.PredictOptions) (string, error) PredictStream(*pb.PredictOptions, chan string) error Load(*pb.ModelOptions) error @@ -13,6 +14,8 @@ type LLM interface { GenerateImage(*pb.GenerateImageRequest) error AudioTranscription(*pb.TranscriptRequest) (api.Result, error) TTS(*pb.TTSRequest) error + TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) + Status() (pb.StatusResponse, error) } func newReply(s string) *pb.Reply { diff --git a/pkg/grpc/llm/bert/bert.go b/pkg/grpc/llm/bert/bert.go index b7df5d76..abdf0102 100644 --- a/pkg/grpc/llm/bert/bert.go +++ 
b/pkg/grpc/llm/bert/bert.go @@ -4,6 +4,7 @@ package bert // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( bert "github.com/go-skynet/go-bert.cpp" + "github.com/rs/zerolog/log" "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" @@ -15,12 +16,21 @@ type Embeddings struct { } func (llm *Embeddings) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("bert backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := bert.New(opts.ModelFile) llm.bert = model return err } func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + llm.Base.Lock() + defer llm.Base.Unlock() + if len(opts.EmbeddingTokens) > 0 { tokens := []int{} for _, t := range opts.EmbeddingTokens { diff --git a/pkg/grpc/llm/bloomz/bloomz.go b/pkg/grpc/llm/bloomz/bloomz.go index fecaa5b7..304bab30 100644 --- a/pkg/grpc/llm/bloomz/bloomz.go +++ b/pkg/grpc/llm/bloomz/bloomz.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" "github.com/go-skynet/bloomz.cpp" ) @@ -18,6 +19,12 @@ type LLM struct { } func (llm *LLM) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("bloomz backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := bloomz.New(opts.ModelFile) llm.bloomz = model return err @@ -40,11 +47,16 @@ func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() + return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) 
} // fallback to Predict func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() + go func() { res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -53,6 +65,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro } results <- res close(results) + llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/falcon/falcon.go b/pkg/grpc/llm/falcon/falcon.go index d287ef6d..c2b9d03f 100644 --- a/pkg/grpc/llm/falcon/falcon.go +++ b/pkg/grpc/llm/falcon/falcon.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" ggllm "github.com/mudler/go-ggllm.cpp" ) @@ -18,6 +19,13 @@ type LLM struct { } func (llm *LLM) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("falcon backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() + ggllmOpts := []ggllm.ModelOption{} if opts.ContextSize != 0 { ggllmOpts = append(ggllmOpts, ggllm.SetContext(int(opts.ContextSize))) @@ -118,10 +126,14 @@ func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) 
} func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() + predictOptions := buildPredictOptions(opts) predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool { @@ -138,6 +150,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro fmt.Println("err: ", err) } close(results) + llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/gpt4all/gpt4all.go b/pkg/grpc/llm/gpt4all/gpt4all.go index f94e3309..0b485120 100644 --- a/pkg/grpc/llm/gpt4all/gpt4all.go +++ b/pkg/grpc/llm/gpt4all/gpt4all.go @@ -8,6 +8,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" + "github.com/rs/zerolog/log" ) type LLM struct { @@ -17,6 +18,13 @@ type LLM struct { } func (llm *LLM) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("gpt4all backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() + model, err := gpt4all.New(opts.ModelFile, gpt4all.SetThreads(int(opts.Threads)), gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath)) @@ -39,10 +47,15 @@ func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() + return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...) 
} func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() + predictOptions := buildPredictOptions(opts) go func() { @@ -56,6 +69,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro } llm.gpt4all.SetTokenCallback(nil) close(results) + llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/langchain/langchain.go b/pkg/grpc/llm/langchain/langchain.go index 5d5f94bd..cd3fd12b 100644 --- a/pkg/grpc/llm/langchain/langchain.go +++ b/pkg/grpc/llm/langchain/langchain.go @@ -8,6 +8,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/LocalAI/pkg/langchain" + "github.com/rs/zerolog/log" ) type LLM struct { @@ -18,12 +19,21 @@ type LLM struct { } func (llm *LLM) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("langchain backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() llm.langchain, _ = langchain.NewHuggingFace(opts.Model) llm.model = opts.Model return nil } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() + o := []langchain.PredictOption{ langchain.SetModel(llm.model), langchain.SetMaxTokens(int(opts.Tokens)), @@ -38,6 +48,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { } func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() o := []langchain.PredictOption{ langchain.SetModel(llm.model), langchain.SetMaxTokens(int(opts.Tokens)), @@ -52,6 +63,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro } results <- res.Completion close(results) + llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/llama/llama.go b/pkg/grpc/llm/llama/llama.go index da46c9f5..594dfb97 100644 --- a/pkg/grpc/llm/llama/llama.go +++ 
b/pkg/grpc/llm/llama/llama.go @@ -8,6 +8,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/go-llama.cpp" + "github.com/rs/zerolog/log" ) type LLM struct { @@ -18,6 +19,13 @@ type LLM struct { func (llm *LLM) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("llama backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() + ropeFreqBase := float32(10000) ropeFreqScale := float32(1) @@ -73,6 +81,7 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error { model, err := llama.New(opts.ModelFile, llamaOpts...) llm.llama = model + return err } @@ -167,10 +176,14 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...) } func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() + predictOptions := buildPredictOptions(opts) predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool { @@ -184,12 +197,16 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro fmt.Println("err: ", err) } close(results) + llm.Base.Unlock() }() return nil } func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + llm.Base.Lock() + defer llm.Base.Unlock() + predictOptions := buildPredictOptions(opts) if len(opts.EmbeddingTokens) > 0 { @@ -202,3 +219,18 @@ func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { return llm.llama.Embeddings(opts.Embeddings, predictOptions...) 
} + +func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) { + llm.Base.Lock() + defer llm.Base.Unlock() + + predictOptions := buildPredictOptions(opts) + l, tokens, err := llm.llama.TokenizeString(opts.Prompt, predictOptions...) + if err != nil { + return pb.TokenizationResponse{}, err + } + return pb.TokenizationResponse{ + Length: l, + Tokens: tokens, + }, nil +} diff --git a/pkg/grpc/llm/rwkv/rwkv.go b/pkg/grpc/llm/rwkv/rwkv.go index cfd60f0c..3658befb 100644 --- a/pkg/grpc/llm/rwkv/rwkv.go +++ b/pkg/grpc/llm/rwkv/rwkv.go @@ -9,6 +9,7 @@ import ( "github.com/donomii/go-rwkv.cpp" "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" ) const tokenizerSuffix = ".tokenizer.json" @@ -20,6 +21,12 @@ type LLM struct { } func (llm *LLM) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("rwkv backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() modelPath := filepath.Dir(opts.ModelFile) modelFile := filepath.Base(opts.ModelFile) model := rwkv.LoadFiles(opts.ModelFile, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads())) @@ -32,6 +39,8 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() stopWord := "\n" if len(opts.StopPrompts) > 0 { @@ -48,6 +57,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { } func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() go func() { stopWord := "\n" @@ -65,6 +75,7 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro return true }) close(results) + llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/transformers/dolly.go 
b/pkg/grpc/llm/transformers/dolly.go index 6c9e0a5d..220490a7 100644 --- a/pkg/grpc/llm/transformers/dolly.go +++ b/pkg/grpc/llm/transformers/dolly.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) @@ -18,17 +19,27 @@ type Dolly struct { } func (llm *Dolly) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("dolly backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := transformers.NewDolly(opts.ModelFile) llm.dolly = model return err } func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() + go func() { res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -37,6 +48,7 @@ func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) er } results <- res close(results) + llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/transformers/falcon.go b/pkg/grpc/llm/transformers/falcon.go index 54e9b320..fceb10c4 100644 --- a/pkg/grpc/llm/transformers/falcon.go +++ b/pkg/grpc/llm/transformers/falcon.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) @@ -18,17 +19,26 @@ type Falcon struct { } func (llm *Falcon) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("transformers-falcon backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := transformers.NewFalcon(opts.ModelFile) llm.falcon = model return err } func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() go func() { res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -37,6 +47,7 @@ func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) e } results <- res close(results) + llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/transformers/gpt2.go b/pkg/grpc/llm/transformers/gpt2.go index 66517d7c..53b364e9 100644 --- a/pkg/grpc/llm/transformers/gpt2.go +++ b/pkg/grpc/llm/transformers/gpt2.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) @@ -18,17 +19,26 @@ type GPT2 struct { } func (llm *GPT2) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("gpt2 backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := transformers.New(opts.ModelFile) llm.gpt2 = model return err } func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() go func() { res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -37,6 +47,7 @@ func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) err } results <- res close(results) + llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/gptj.go b/pkg/grpc/llm/transformers/gptj.go index 6f692188..c798c3df 100644 --- a/pkg/grpc/llm/transformers/gptj.go +++ b/pkg/grpc/llm/transformers/gptj.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) @@ -18,17 +19,26 @@ type GPTJ struct { } func (llm *GPTJ) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("gptj backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := transformers.NewGPTJ(opts.ModelFile) llm.gptj = model return err } func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() go func() { res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -37,6 +47,7 @@ func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) err } results <- res close(results) + llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/gptneox.go b/pkg/grpc/llm/transformers/gptneox.go index bcff0abe..bcaa8da6 100644 --- a/pkg/grpc/llm/transformers/gptneox.go +++ b/pkg/grpc/llm/transformers/gptneox.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) @@ -18,17 +19,26 @@ type GPTNeoX struct { } func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("gptneox backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := transformers.NewGPTNeoX(opts.ModelFile) llm.gptneox = model return err } func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() go func() { res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -37,6 +47,7 @@ func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) } results <- res close(results) + llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/mpt.go b/pkg/grpc/llm/transformers/mpt.go index 8adda361..1b9272ee 100644 --- a/pkg/grpc/llm/transformers/mpt.go +++ b/pkg/grpc/llm/transformers/mpt.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) @@ -18,17 +19,27 @@ type MPT struct { } func (llm *MPT) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("mpt backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := transformers.NewMPT(opts.ModelFile) llm.mpt = model return err } func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() + return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() go func() { res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -37,6 +48,7 @@ func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) erro } results <- res close(results) + llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/replit.go b/pkg/grpc/llm/transformers/replit.go index bfc3a90a..0c1fc066 100644 --- a/pkg/grpc/llm/transformers/replit.go +++ b/pkg/grpc/llm/transformers/replit.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) @@ -18,17 +19,26 @@ type Replit struct { } func (llm *Replit) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("replit backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := transformers.NewReplit(opts.ModelFile) llm.replit = model return err } func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() go func() { res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -37,6 +47,7 @@ func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) e } results <- res close(results) + llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/starcoder.go b/pkg/grpc/llm/transformers/starcoder.go index 06948cb3..c63256f9 100644 --- a/pkg/grpc/llm/transformers/starcoder.go +++ b/pkg/grpc/llm/transformers/starcoder.go @@ -7,6 +7,7 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) @@ -18,17 +19,26 @@ type Starcoder struct { } func (llm *Starcoder) Load(opts *pb.ModelOptions) error { + if llm.Base.State != pb.StatusResponse_UNINITIALIZED { + log.Warn().Msgf("starcoder backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) + } + + llm.Base.Lock() + defer llm.Base.Unlock() model, err := transformers.NewStarcoder(opts.ModelFile) llm.starcoder = model return err } func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) { + llm.Base.Lock() + defer llm.Base.Unlock() return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error { + llm.Base.Lock() go func() { res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -37,6 +47,7 @@ func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string } results <- res close(results) + llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/proto/backend.pb.go b/pkg/grpc/proto/backend.pb.go index 87456b13..dd53df52 100644 --- a/pkg/grpc/proto/backend.pb.go +++ b/pkg/grpc/proto/backend.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: -// protoc-gen-go v1.26.0 -// protoc v3.15.8 +// protoc-gen-go v1.27.1 +// protoc v3.12.4 // source: pkg/grpc/proto/backend.proto package proto @@ -20,6 +20,58 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type StatusResponse_State int32 + +const ( + StatusResponse_UNINITIALIZED StatusResponse_State = 0 + StatusResponse_BUSY StatusResponse_State = 1 + StatusResponse_READY StatusResponse_State = 2 + StatusResponse_ERROR StatusResponse_State = -1 +) + +// Enum value maps for StatusResponse_State. +var ( + StatusResponse_State_name = map[int32]string{ + 0: "UNINITIALIZED", + 1: "BUSY", + 2: "READY", + -1: "ERROR", + } + StatusResponse_State_value = map[string]int32{ + "UNINITIALIZED": 0, + "BUSY": 1, + "READY": 2, + "ERROR": -1, + } +) + +func (x StatusResponse_State) Enum() *StatusResponse_State { + p := new(StatusResponse_State) + *p = x + return p +} + +func (x StatusResponse_State) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (StatusResponse_State) Descriptor() protoreflect.EnumDescriptor { + return file_pkg_grpc_proto_backend_proto_enumTypes[0].Descriptor() +} + +func (StatusResponse_State) Type() protoreflect.EnumType { + return &file_pkg_grpc_proto_backend_proto_enumTypes[0] +} + +func (x StatusResponse_State) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use StatusResponse_State.Descriptor instead. 
+func (StatusResponse_State) EnumDescriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{13, 0} +} + type HealthMessage struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1253,6 +1305,171 @@ func (x *TTSRequest) GetDst() string { return "" } +type TokenizationResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Length int32 `protobuf:"varint,1,opt,name=length,proto3" json:"length,omitempty"` + Tokens []int32 `protobuf:"varint,2,rep,packed,name=tokens,proto3" json:"tokens,omitempty"` +} + +func (x *TokenizationResponse) Reset() { + *x = TokenizationResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TokenizationResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TokenizationResponse) ProtoMessage() {} + +func (x *TokenizationResponse) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[11] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TokenizationResponse.ProtoReflect.Descriptor instead. 
+func (*TokenizationResponse) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{11} +} + +func (x *TokenizationResponse) GetLength() int32 { + if x != nil { + return x.Length + } + return 0 +} + +func (x *TokenizationResponse) GetTokens() []int32 { + if x != nil { + return x.Tokens + } + return nil +} + +type MemoryUsageData struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Total uint64 `protobuf:"varint,1,opt,name=total,proto3" json:"total,omitempty"` + Breakdown map[string]uint64 `protobuf:"bytes,2,rep,name=breakdown,proto3" json:"breakdown,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"varint,2,opt,name=value,proto3"` +} + +func (x *MemoryUsageData) Reset() { + *x = MemoryUsageData{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MemoryUsageData) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MemoryUsageData) ProtoMessage() {} + +func (x *MemoryUsageData) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[12] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MemoryUsageData.ProtoReflect.Descriptor instead. 
+func (*MemoryUsageData) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{12} +} + +func (x *MemoryUsageData) GetTotal() uint64 { + if x != nil { + return x.Total + } + return 0 +} + +func (x *MemoryUsageData) GetBreakdown() map[string]uint64 { + if x != nil { + return x.Breakdown + } + return nil +} + +type StatusResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + State StatusResponse_State `protobuf:"varint,1,opt,name=state,proto3,enum=backend.StatusResponse_State" json:"state,omitempty"` + Memory *MemoryUsageData `protobuf:"bytes,2,opt,name=memory,proto3" json:"memory,omitempty"` +} + +func (x *StatusResponse) Reset() { + *x = StatusResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StatusResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StatusResponse) ProtoMessage() {} + +func (x *StatusResponse) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[13] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead. 
+func (*StatusResponse) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{13} +} + +func (x *StatusResponse) GetState() StatusResponse_State { + if x != nil { + return x.State + } + return StatusResponse_UNINITIALIZED +} + +func (x *StatusResponse) GetMemory() *MemoryUsageData { + if x != nil { + return x.Memory + } + return nil +} + var File_pkg_grpc_proto_backend_proto protoreflect.FileDescriptor var file_pkg_grpc_proto_backend_proto_rawDesc = []byte{ @@ -1451,44 +1668,80 @@ var file_pkg_grpc_proto_backend_proto_rawDesc = []byte{ 0x04, 0x74, 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x73, 0x74, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x73, 0x74, 0x32, 0xeb, 0x03, 0x0a, 0x07, 0x42, 0x61, - 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x12, 0x32, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, - 0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, - 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x07, 0x50, 0x72, 0x65, - 0x64, 0x69, 0x63, 0x74, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, - 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x73, 0x74, 0x22, 0x46, 0x0a, 0x14, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x6f, 0x6b, + 0x65, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 
0x05, 0x52, 0x06, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x73, 0x22, 0xac, 0x01, 0x0a, 0x0f, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x55, 0x73, 0x61, 0x67, + 0x65, 0x44, 0x61, 0x74, 0x61, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x12, 0x45, 0x0a, 0x09, 0x62, + 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f, 0x77, 0x6e, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x27, + 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x55, + 0x73, 0x61, 0x67, 0x65, 0x44, 0x61, 0x74, 0x61, 0x2e, 0x42, 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f, + 0x77, 0x6e, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x09, 0x62, 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f, + 0x77, 0x6e, 0x1a, 0x3c, 0x0a, 0x0e, 0x42, 0x72, 0x65, 0x61, 0x6b, 0x64, 0x6f, 0x77, 0x6e, 0x45, + 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, + 0x22, 0xbc, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x33, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x12, 0x30, 0x0a, 0x06, 0x6d, 0x65, 0x6d, 0x6f, + 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, + 0x6e, 0x64, 0x2e, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x55, 0x73, 0x61, 0x67, 0x65, 0x44, 0x61, + 0x74, 0x61, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x22, 0x43, 0x0a, 0x05, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x12, 0x11, 0x0a, 0x0d, 0x55, 0x4e, 0x49, 0x4e, 0x49, 0x54, 0x49, 
0x41, 0x4c, + 0x49, 0x5a, 0x45, 0x44, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x42, 0x55, 0x53, 0x59, 0x10, 0x01, + 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x41, 0x44, 0x59, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x05, 0x45, + 0x52, 0x52, 0x4f, 0x52, 0x10, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, 0x32, + 0xf4, 0x04, 0x0a, 0x07, 0x42, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x12, 0x32, 0x0a, 0x06, 0x48, + 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, + 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, - 0x35, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x62, - 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, - 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x3c, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, - 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, + 0x34, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, + 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, + 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x35, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, + 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x4d, 0x6f, 0x64, + 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, + 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x3c, 0x0a, 0x0d, + 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x17, 0x2e, + 0x62, 0x61, 0x63, 0x6b, 
0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, + 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x12, 0x40, 0x0a, 0x09, 0x45, 0x6d, + 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, - 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, - 0x22, 0x00, 0x30, 0x01, 0x12, 0x40, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, - 0x67, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, - 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x18, 0x2e, 0x62, 0x61, 0x63, - 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, - 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x41, 0x0a, 0x0d, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, - 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, - 0x64, 0x2e, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, - 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x12, 0x41, 0x75, 0x64, - 0x69, 0x6f, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, - 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x62, 0x61, - 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, - 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x03, 0x54, 0x54, 0x53, 0x12, - 0x13, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 
0x54, 0x53, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, - 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x42, 0x5a, 0x0a, 0x19, 0x69, 0x6f, 0x2e, 0x73, 0x6b, - 0x79, 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x62, 0x61, 0x63, - 0x6b, 0x65, 0x6e, 0x64, 0x42, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x42, 0x61, 0x63, - 0x6b, 0x65, 0x6e, 0x64, 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, - 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, - 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x1a, 0x18, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, + 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x41, 0x0a, 0x0d, + 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x2e, + 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, + 0x49, 0x6d, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, + 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, + 0x4d, 0x0a, 0x12, 0x41, 0x75, 0x64, 0x69, 0x6f, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, + 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x19, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, + 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x2d, + 0x0a, 0x03, 0x54, 0x54, 0x53, 0x12, 0x13, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, + 0x54, 0x54, 0x53, 0x52, 0x65, 0x71, 
0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, + 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x4a, 0x0a, + 0x0e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x69, 0x7a, 0x65, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x12, + 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, + 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, + 0x6e, 0x64, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x48, 0x65, + 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x17, 0x2e, 0x62, 0x61, + 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x5a, 0x0a, 0x19, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, + 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x62, 0x61, 0x63, 0x6b, + 0x65, 0x6e, 0x64, 0x42, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x42, 0x61, 0x63, 0x6b, + 0x65, 0x6e, 0x64, 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, + 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -1503,43 +1756,56 @@ func file_pkg_grpc_proto_backend_proto_rawDescGZIP() []byte { return file_pkg_grpc_proto_backend_proto_rawDescData } -var file_pkg_grpc_proto_backend_proto_msgTypes = make([]protoimpl.MessageInfo, 11) +var file_pkg_grpc_proto_backend_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_pkg_grpc_proto_backend_proto_msgTypes = make([]protoimpl.MessageInfo, 15) var 
file_pkg_grpc_proto_backend_proto_goTypes = []interface{}{ - (*HealthMessage)(nil), // 0: backend.HealthMessage - (*PredictOptions)(nil), // 1: backend.PredictOptions - (*Reply)(nil), // 2: backend.Reply - (*ModelOptions)(nil), // 3: backend.ModelOptions - (*Result)(nil), // 4: backend.Result - (*EmbeddingResult)(nil), // 5: backend.EmbeddingResult - (*TranscriptRequest)(nil), // 6: backend.TranscriptRequest - (*TranscriptResult)(nil), // 7: backend.TranscriptResult - (*TranscriptSegment)(nil), // 8: backend.TranscriptSegment - (*GenerateImageRequest)(nil), // 9: backend.GenerateImageRequest - (*TTSRequest)(nil), // 10: backend.TTSRequest + (StatusResponse_State)(0), // 0: backend.StatusResponse.State + (*HealthMessage)(nil), // 1: backend.HealthMessage + (*PredictOptions)(nil), // 2: backend.PredictOptions + (*Reply)(nil), // 3: backend.Reply + (*ModelOptions)(nil), // 4: backend.ModelOptions + (*Result)(nil), // 5: backend.Result + (*EmbeddingResult)(nil), // 6: backend.EmbeddingResult + (*TranscriptRequest)(nil), // 7: backend.TranscriptRequest + (*TranscriptResult)(nil), // 8: backend.TranscriptResult + (*TranscriptSegment)(nil), // 9: backend.TranscriptSegment + (*GenerateImageRequest)(nil), // 10: backend.GenerateImageRequest + (*TTSRequest)(nil), // 11: backend.TTSRequest + (*TokenizationResponse)(nil), // 12: backend.TokenizationResponse + (*MemoryUsageData)(nil), // 13: backend.MemoryUsageData + (*StatusResponse)(nil), // 14: backend.StatusResponse + nil, // 15: backend.MemoryUsageData.BreakdownEntry } var file_pkg_grpc_proto_backend_proto_depIdxs = []int32{ - 8, // 0: backend.TranscriptResult.segments:type_name -> backend.TranscriptSegment - 0, // 1: backend.Backend.Health:input_type -> backend.HealthMessage - 1, // 2: backend.Backend.Predict:input_type -> backend.PredictOptions - 3, // 3: backend.Backend.LoadModel:input_type -> backend.ModelOptions - 1, // 4: backend.Backend.PredictStream:input_type -> backend.PredictOptions - 1, // 5: 
backend.Backend.Embedding:input_type -> backend.PredictOptions - 9, // 6: backend.Backend.GenerateImage:input_type -> backend.GenerateImageRequest - 6, // 7: backend.Backend.AudioTranscription:input_type -> backend.TranscriptRequest - 10, // 8: backend.Backend.TTS:input_type -> backend.TTSRequest - 2, // 9: backend.Backend.Health:output_type -> backend.Reply - 2, // 10: backend.Backend.Predict:output_type -> backend.Reply - 4, // 11: backend.Backend.LoadModel:output_type -> backend.Result - 2, // 12: backend.Backend.PredictStream:output_type -> backend.Reply - 5, // 13: backend.Backend.Embedding:output_type -> backend.EmbeddingResult - 4, // 14: backend.Backend.GenerateImage:output_type -> backend.Result - 7, // 15: backend.Backend.AudioTranscription:output_type -> backend.TranscriptResult - 4, // 16: backend.Backend.TTS:output_type -> backend.Result - 9, // [9:17] is the sub-list for method output_type - 1, // [1:9] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 9, // 0: backend.TranscriptResult.segments:type_name -> backend.TranscriptSegment + 15, // 1: backend.MemoryUsageData.breakdown:type_name -> backend.MemoryUsageData.BreakdownEntry + 0, // 2: backend.StatusResponse.state:type_name -> backend.StatusResponse.State + 13, // 3: backend.StatusResponse.memory:type_name -> backend.MemoryUsageData + 1, // 4: backend.Backend.Health:input_type -> backend.HealthMessage + 2, // 5: backend.Backend.Predict:input_type -> backend.PredictOptions + 4, // 6: backend.Backend.LoadModel:input_type -> backend.ModelOptions + 2, // 7: backend.Backend.PredictStream:input_type -> backend.PredictOptions + 2, // 8: backend.Backend.Embedding:input_type -> backend.PredictOptions + 10, // 9: backend.Backend.GenerateImage:input_type -> backend.GenerateImageRequest + 7, // 10: backend.Backend.AudioTranscription:input_type -> 
backend.TranscriptRequest + 11, // 11: backend.Backend.TTS:input_type -> backend.TTSRequest + 2, // 12: backend.Backend.TokenizeString:input_type -> backend.PredictOptions + 1, // 13: backend.Backend.Status:input_type -> backend.HealthMessage + 3, // 14: backend.Backend.Health:output_type -> backend.Reply + 3, // 15: backend.Backend.Predict:output_type -> backend.Reply + 5, // 16: backend.Backend.LoadModel:output_type -> backend.Result + 3, // 17: backend.Backend.PredictStream:output_type -> backend.Reply + 6, // 18: backend.Backend.Embedding:output_type -> backend.EmbeddingResult + 5, // 19: backend.Backend.GenerateImage:output_type -> backend.Result + 8, // 20: backend.Backend.AudioTranscription:output_type -> backend.TranscriptResult + 5, // 21: backend.Backend.TTS:output_type -> backend.Result + 12, // 22: backend.Backend.TokenizeString:output_type -> backend.TokenizationResponse + 14, // 23: backend.Backend.Status:output_type -> backend.StatusResponse + 14, // [14:24] is the sub-list for method output_type + 4, // [4:14] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name } func init() { file_pkg_grpc_proto_backend_proto_init() } @@ -1680,19 +1946,56 @@ func file_pkg_grpc_proto_backend_proto_init() { return nil } } + file_pkg_grpc_proto_backend_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TokenizationResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MemoryUsageData); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[13].Exporter = func(v 
interface{}, i int) interface{} { + switch v := v.(*StatusResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_pkg_grpc_proto_backend_proto_rawDesc, - NumEnums: 0, - NumMessages: 11, + NumEnums: 1, + NumMessages: 15, NumExtensions: 0, NumServices: 1, }, GoTypes: file_pkg_grpc_proto_backend_proto_goTypes, DependencyIndexes: file_pkg_grpc_proto_backend_proto_depIdxs, + EnumInfos: file_pkg_grpc_proto_backend_proto_enumTypes, MessageInfos: file_pkg_grpc_proto_backend_proto_msgTypes, }.Build() File_pkg_grpc_proto_backend_proto = out.File diff --git a/pkg/grpc/proto/backend.proto b/pkg/grpc/proto/backend.proto index 30b309f5..b83d7526 100644 --- a/pkg/grpc/proto/backend.proto +++ b/pkg/grpc/proto/backend.proto @@ -16,6 +16,8 @@ service Backend { rpc GenerateImage(GenerateImageRequest) returns (Result) {} rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} rpc TTS(TTSRequest) returns (Result) {} + rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {} + rpc Status(HealthMessage) returns (StatusResponse) {} } message HealthMessage {} @@ -157,3 +159,24 @@ message TTSRequest { string model = 2; string dst = 3; } + +message TokenizationResponse { + int32 length = 1; + repeated int32 tokens = 2; +} + +message MemoryUsageData { + uint64 total = 1; + map<string, uint64> breakdown = 2; +} + +message StatusResponse { + enum State { + UNINITIALIZED = 0; + BUSY = 1; + READY = 2; + ERROR = -1; + } + State state = 1; + MemoryUsageData memory = 2; +} \ No newline at end of file diff --git a/pkg/grpc/proto/backend_grpc.pb.go b/pkg/grpc/proto/backend_grpc.pb.go index b9d7dd8b..bb865949 100644 --- a/pkg/grpc/proto/backend_grpc.pb.go +++ b/pkg/grpc/proto/backend_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
// versions: -// - protoc-gen-go-grpc v1.2.0 -// - protoc v3.15.8 +// - protoc-gen-go-grpc v1.3.0 +// - protoc v3.12.4 // source: pkg/grpc/proto/backend.proto package proto @@ -18,6 +18,19 @@ import ( // Requires gRPC-Go v1.32.0 or later. const _ = grpc.SupportPackageIsVersion7 +const ( + Backend_Health_FullMethodName = "/backend.Backend/Health" + Backend_Predict_FullMethodName = "/backend.Backend/Predict" + Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel" + Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream" + Backend_Embedding_FullMethodName = "/backend.Backend/Embedding" + Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage" + Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription" + Backend_TTS_FullMethodName = "/backend.Backend/TTS" + Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString" + Backend_Status_FullMethodName = "/backend.Backend/Status" +) + // BackendClient is the client API for Backend service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
@@ -30,6 +43,8 @@ type BackendClient interface { GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) + TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) + Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) } type backendClient struct { @@ -42,7 +57,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient { func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { out := new(Reply) - err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -51,7 +66,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { out := new(Reply) - err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -60,7 +75,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts .. func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { out := new(Result) - err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -68,7 +83,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts .. 
} func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { - stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...) + stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...) if err != nil { return nil, err } @@ -101,7 +116,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) { func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { out := new(EmbeddingResult) - err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -110,7 +125,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { out := new(Result) - err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -119,7 +134,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { out := new(TranscriptResult) - err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...) 
if err != nil { return nil, err } @@ -128,7 +143,25 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { out := new(Result) - err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) { + out := new(TokenizationResponse) + err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) { + out := new(StatusResponse) + err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...) 
if err != nil { return nil, err } @@ -147,6 +180,8 @@ type BackendServer interface { GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) TTS(context.Context, *TTSRequest) (*Result, error) + TokenizeString(context.Context, *PredictOptions) (*TokenizationResponse, error) + Status(context.Context, *HealthMessage) (*StatusResponse, error) mustEmbedUnimplementedBackendServer() } @@ -178,6 +213,12 @@ func (UnimplementedBackendServer) AudioTranscription(context.Context, *Transcrip func (UnimplementedBackendServer) TTS(context.Context, *TTSRequest) (*Result, error) { return nil, status.Errorf(codes.Unimplemented, "method TTS not implemented") } +func (UnimplementedBackendServer) TokenizeString(context.Context, *PredictOptions) (*TokenizationResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method TokenizeString not implemented") +} +func (UnimplementedBackendServer) Status(context.Context, *HealthMessage) (*StatusResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Status not implemented") +} func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {} // UnsafeBackendServer may be embedded to opt out of forward compatibility for this service. 
@@ -201,7 +242,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/Health", + FullMethod: Backend_Health_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) @@ -219,7 +260,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/Predict", + FullMethod: Backend_Predict_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) @@ -237,7 +278,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/LoadModel", + FullMethod: Backend_LoadModel_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) @@ -276,7 +317,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/Embedding", + FullMethod: Backend_Embedding_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) @@ -294,7 +335,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/GenerateImage", + FullMethod: Backend_GenerateImage_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) @@ -312,7 +353,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx 
context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/AudioTranscription", + FullMethod: Backend_AudioTranscription_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) @@ -330,7 +371,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/TTS", + FullMethod: Backend_TTS_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) @@ -338,6 +379,42 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa return interceptor(ctx, in, info, handler) } +func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PredictOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).TokenizeString(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Backend_TokenizeString_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HealthMessage) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).Status(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Backend_Status_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return 
srv.(BackendServer).Status(ctx, req.(*HealthMessage)) + } + return interceptor(ctx, in, info, handler) +} + // Backend_ServiceDesc is the grpc.ServiceDesc for Backend service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -373,6 +450,14 @@ var Backend_ServiceDesc = grpc.ServiceDesc{ MethodName: "TTS", Handler: _Backend_TTS_Handler, }, + { + MethodName: "TokenizeString", + Handler: _Backend_TokenizeString_Handler, + }, + { + MethodName: "Status", + Handler: _Backend_Status_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 45e7d143..58ea4e7e 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -110,6 +110,32 @@ func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictS return nil } +func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) { + res, err := s.llm.TokenizeString(in) + if err != nil { + return nil, err + } + + castTokens := make([]int32, len(res.Tokens)) + for i, v := range res.Tokens { + castTokens[i] = int32(v) + } + + return &pb.TokenizationResponse{ + Length: int32(res.Length), + Tokens: castTokens, + }, err +} + +func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusResponse, error) { + res, err := s.llm.Status() + if err != nil { + return nil, err + } + + return &res, nil +} + func StartServer(address string, model LLM) error { lis, err := net.Listen("tcp", address) if err != nil { diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 49c472f7..14135a98 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "path/filepath" + "strconv" "strings" "syscall" "time" @@ -64,10 +65,33 @@ var AutoLoadBackends []string = []string{ PiperBackend, } -func (ml *ModelLoader) StopGRPC() { - for _, p := range ml.grpcProcesses { - p.Stop() +func (ml 
*ModelLoader) GetGRPCPID(id string) (int, error) { + p, exists := ml.grpcProcesses[id] + if !exists { + return -1, fmt.Errorf("no grpc backend found for %s", id) } + return strconv.Atoi(p.PID) +} + +type GRPCProcessFilter = func(p *process.Process) bool + +func includeAllProcesses(_ *process.Process) bool { + return true +} + +func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) { + for _, p := range ml.grpcProcesses { + if filter(p) { + p.Stop() + } + } +} + +func (ml *ModelLoader) StopAllGRPC() { + ml.StopGRPC(includeAllProcesses) + // for _, p := range ml.grpcProcesses { + // p.Stop() + // } } func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error { @@ -252,7 +276,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { // Is this really needed? BackendLoader already does this ml.mu.Lock() - if m := ml.checkIsLoaded(o.model); m != nil { + if m := ml.CheckIsLoaded(o.model); m != nil { log.Debug().Msgf("Model '%s' already loaded", o.model) ml.mu.Unlock() return m, nil diff --git a/pkg/model/loader.go b/pkg/model/loader.go index b45a52c8..4191cea1 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -103,7 +103,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) ( defer ml.mu.Unlock() // Check if we already have a loaded model - if model := ml.checkIsLoaded(modelName); model != nil { + if model := ml.CheckIsLoaded(modelName); model != nil { return model, nil } @@ -128,7 +128,7 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) ( return model, nil } -func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client { +func (ml *ModelLoader) CheckIsLoaded(s string) *grpc.Client { if m, ok := ml.models[s]; ok { log.Debug().Msgf("Model already loaded in memory: %s", s)