From afdc0ebfd7c4d7d6ea87359c15a8d302b09ff659 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 19 Aug 2023 01:49:33 +0200 Subject: [PATCH] feat: add --single-active-backend to allow only one backend active at the time (#925) Signed-off-by: Ettore Di Giacinto --- api/backend/embeddings.go | 16 +-- api/backend/image.go | 16 +-- api/backend/llm.go | 16 +-- api/backend/options.go | 22 ++++ api/backend/transcript.go | 16 +-- api/backend/tts.go | 10 +- api/options/options.go | 6 ++ extra/grpc/autogptq/autogptq.py | 2 +- extra/grpc/bark/ttsbark.py | 2 +- extra/grpc/diffusers/backend_diffusers.py | 2 +- extra/grpc/exllama/exllama.py | 2 +- extra/grpc/huggingface/huggingface.py | 2 +- go.mod | 4 +- go.sum | 4 + main.go | 9 ++ pkg/grpc/client.go | 35 +++++++ pkg/model/initializers.go | 105 +++---------------- pkg/model/loader.go | 4 +- pkg/model/options.go | 11 +- pkg/model/process.go | 118 ++++++++++++++++++++++ 20 files changed, 238 insertions(+), 164 deletions(-) create mode 100644 pkg/model/process.go diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go index aa1e393f..63f1a831 100644 --- a/api/backend/embeddings.go +++ b/api/backend/embeddings.go @@ -21,25 +21,13 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. var inferenceModel interface{} var err error - opts := []model.Option{ + opts := modelOpts(c, o, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), model.WithThreads(uint32(c.Threads)), model.WithAssetDir(o.AssetsDestination), model.WithModel(modelFile), model.WithContext(o.Context), - } - - if c.GRPC.Attempts != 0 { - opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) - } - - if c.GRPC.AttemptsSleepTime != 0 { - opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) - } - - for k, v := range o.ExternalGRPCBackends { - opts = append(opts, model.WithExternalBackend(k, v)) - } + }) if c.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) diff --git a/api/backend/image.go b/api/backend/image.go index 9c9ad6c0..ea3f2069 100644 --- a/api/backend/image.go +++ b/api/backend/image.go @@ -9,7 +9,7 @@ import ( func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { - opts := []model.Option{ + opts := modelOpts(c, o, []model.Option{ model.WithBackendString(c.Backend), model.WithAssetDir(o.AssetsDestination), model.WithThreads(uint32(c.Threads)), @@ -25,19 +25,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat CLIPSubfolder: c.Diffusers.ClipSubFolder, CLIPSkip: int32(c.Diffusers.ClipSkip), }), - } - - if c.GRPC.Attempts != 0 { - opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) - } - - if c.GRPC.AttemptsSleepTime != 0 { - opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) - } - - for k, v := range o.ExternalGRPCBackends { - opts = append(opts, model.WithExternalBackend(k, v)) - } + }) inferenceModel, err := loader.BackendLoader( opts..., diff --git a/api/backend/llm.go b/api/backend/llm.go index c30e0f81..e40db21a 100644 --- a/api/backend/llm.go +++ b/api/backend/llm.go @@ -33,25 +33,13 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c var inferenceModel *grpc.Client var err error - opts := []model.Option{ + opts := modelOpts(c, o, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup model.WithAssetDir(o.AssetsDestination), model.WithModel(modelFile), model.WithContext(o.Context), - } - - if c.GRPC.Attempts != 0 { - opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) - } - - if c.GRPC.AttemptsSleepTime != 0 { - opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) - } - - for k, v := range o.ExternalGRPCBackends { - opts = append(opts, model.WithExternalBackend(k, v)) - } + }) if c.Backend != "" { opts = append(opts, model.WithBackendString(c.Backend)) diff --git a/api/backend/options.go b/api/backend/options.go index 4794fc90..af1081ad 100644 --- a/api/backend/options.go +++ b/api/backend/options.go @@ -5,10 +5,32 @@ import ( "path/filepath" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + model "github.com/go-skynet/LocalAI/pkg/model" config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" ) +func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option { + if o.SingleBackend { + opts = append(opts, model.WithSingleActiveBackend()) + } + + if c.GRPC.Attempts != 0 { + opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) + } + + if c.GRPC.AttemptsSleepTime != 0 { + opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + return opts +} + func gRPCModelOpts(c config.Config) *pb.ModelOptions { b := 512 if c.Batch != 0 { diff --git a/api/backend/transcript.go b/api/backend/transcript.go index 80a759de..fbc2b7ec 100644 --- a/api/backend/transcript.go +++ b/api/backend/transcript.go @@ -13,24 +13,14 @@ import ( ) func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) { - opts := []model.Option{ + + opts := modelOpts(c, o, []model.Option{ model.WithBackendString(model.WhisperBackend), model.WithModel(c.Model), model.WithContext(o.Context), model.WithThreads(uint32(c.Threads)), model.WithAssetDir(o.AssetsDestination), - } - - if c.GRPC.Attempts != 0 { - opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) - } - - if c.GRPC.AttemptsSleepTime != 0 { - opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) - } - for k, v := range o.ExternalGRPCBackends { - opts = append(opts, model.WithExternalBackend(k, v)) - } + }) whisperModel, err := o.Loader.BackendLoader(opts...) if err != nil { diff --git a/api/backend/tts.go b/api/backend/tts.go index 411fc278..a2e56afa 100644 --- a/api/backend/tts.go +++ b/api/backend/tts.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + api_config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" @@ -33,17 +34,12 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt if bb == "" { bb = model.PiperBackend } - opts := []model.Option{ + opts := modelOpts(api_config.Config{}, o, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), model.WithContext(o.Context), model.WithAssetDir(o.AssetsDestination), - } - - for k, v := range o.ExternalGRPCBackends { - opts = append(opts, model.WithExternalBackend(k, v)) - } - + }) piperModel, err := o.Loader.BackendLoader(opts...) if err != nil { return "", nil, err diff --git a/api/options/options.go b/api/options/options.go index ada95d3f..6ffa571c 100644 --- a/api/options/options.go +++ b/api/options/options.go @@ -33,6 +33,8 @@ type Option struct { ExternalGRPCBackends map[string]string AutoloadGalleries bool + + SingleBackend bool } type AppOption func(*Option) @@ -58,6 +60,10 @@ func WithCors(b bool) AppOption { } } +var EnableSingleBackend = func(o *Option) { + o.SingleBackend = true +} + var EnableGalleriesAutoload = func(o *Option) { o.AutoloadGalleries = true } diff --git a/extra/grpc/autogptq/autogptq.py b/extra/grpc/autogptq/autogptq.py index 7d8a45fc..6a5f9c7c 100755 --- a/extra/grpc/autogptq/autogptq.py +++ b/extra/grpc/autogptq/autogptq.py @@ -77,7 +77,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): def serve(address): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/extra/grpc/bark/ttsbark.py b/extra/grpc/bark/ttsbark.py index 63b01798..a14c632d 100644 --- a/extra/grpc/bark/ttsbark.py +++ b/extra/grpc/bark/ttsbark.py @@ -51,7 +51,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return backend_pb2.Result(success=True) def serve(address): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/extra/grpc/diffusers/backend_diffusers.py b/extra/grpc/diffusers/backend_diffusers.py index d93779a4..a005d7f4 100755 --- a/extra/grpc/diffusers/backend_diffusers.py +++ b/extra/grpc/diffusers/backend_diffusers.py @@ -267,7 +267,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return backend_pb2.Result(message="Model loaded successfully", success=True) def serve(address): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/extra/grpc/exllama/exllama.py b/extra/grpc/exllama/exllama.py index e6d979d5..c8eddf4e 100755 --- a/extra/grpc/exllama/exllama.py +++ b/extra/grpc/exllama/exllama.py @@ -110,7 +110,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): def serve(address): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/extra/grpc/huggingface/huggingface.py b/extra/grpc/huggingface/huggingface.py index 7589dfd6..680c2739 100755 --- a/extra/grpc/huggingface/huggingface.py +++ b/extra/grpc/huggingface/huggingface.py @@ -34,7 +34,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): def serve(address): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() diff --git a/go.mod b/go.mod index 84339e9b..fb05a4c8 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/json-iterator/go v1.1.12 github.com/mholt/archiver/v3 v3.5.1 github.com/mudler/go-ggllm.cpp v0.0.0-20230709223052-862477d16eef - github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d + github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230815171941-a63093554fb5 github.com/onsi/ginkgo/v2 v2.11.0 @@ -40,7 +40,7 @@ require ( github.com/go-ole/go-ole v1.2.6 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect - github.com/shirou/gopsutil/v3 v3.23.6 + github.com/shirou/gopsutil/v3 v3.23.7 github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/tklauser/go-sysconf v0.3.11 // indirect github.com/tklauser/numcpus v0.6.0 // indirect diff --git a/go.sum b/go.sum index b085ba67..0e00d744 100644 --- a/go.sum +++ b/go.sum @@ -119,6 +119,8 @@ github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d h1:/lAg9vPAAU+s35cDMCx1IyeMn+4OYfCBPqi08Q8vXDg= github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d/go.mod h1:HGGAOJhipApckwNV8ZTliRJqxctUv3xRY+zbQEwuytc= +github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI= +github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c/go.mod h1:gY3wyrhkRySJtmtI/JPt4a2mKv48h/M9pEZIW+SjeC0= github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks= github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230815171941-a63093554fb5 h1:b4EeYDaGxOLNlNm5LOVEmrUhaw1v6xq/V79ZwWVlY6I= @@ -164,6 +166,8 @@ github.com/sashabaranov/go-openai v1.14.2 h1:5DPTtR9JBjKPJS008/A409I5ntFhUPPGCma github.com/sashabaranov/go-openai v1.14.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/shirou/gopsutil/v3 v3.23.6 h1:5y46WPI9QBKBbK7EEccUPNXpJpNrvPuTD0O2zHEHT08= github.com/shirou/gopsutil/v3 v3.23.6/go.mod h1:j7QX50DrXYggrpN30W0Mo+I4/8U2UUIQrnrhqUeWrAU= +github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4= +github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= diff --git a/main.go b/main.go index 8f5e6445..4473af6c 100644 --- a/main.go +++ b/main.go @@ -49,6 +49,11 @@ func main() { Name: "debug", EnvVars: []string{"DEBUG"}, }, + &cli.BoolFlag{ + Name: "single-active-backend", + EnvVars: []string{"SINGLE_ACTIVE_BACKEND"}, + Usage: "Allow only one backend to be running.", + }, &cli.BoolFlag{ Name: "cors", EnvVars: []string{"CORS"}, @@ -181,6 +186,10 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithApiKeys(ctx.StringSlice("api-keys")), } + if ctx.Bool("single-active-backend") { + opts = append(opts, options.EnableSingleBackend) + } + externalgRPC := ctx.StringSlice("external-grpc-backends") // split ":" to get backend name and the uri for _, v := range externalgRPC { diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index cdc34ad5..d69251ff 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "sync" "time" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" @@ -14,6 +15,8 @@ import ( type Client struct { address string + busy bool + sync.Mutex } func NewClient(address string) *Client { @@ -22,7 +25,21 @@ func NewClient(address string) *Client { } } +func (c *Client) IsBusy() bool { + c.Lock() + defer c.Unlock() + return c.busy +} + +func (c *Client) setBusy(v bool) { + c.Lock() + c.busy = v + c.Unlock() +} + func (c *Client) HealthCheck(ctx context.Context) bool { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { fmt.Println(err) @@ -49,6 +66,8 @@ func (c *Client) HealthCheck(ctx context.Context) bool { } func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -60,6 +79,8 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ... } func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -71,6 +92,8 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp } func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -81,6 +104,8 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp } func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return err @@ -110,6 +135,8 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun } func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -120,6 +147,8 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, } func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -130,6 +159,8 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp } func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -160,6 +191,8 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques } func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -176,6 +209,8 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts } func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) { + c.setBusy(true) + defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 14135a98..2c9e0de9 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -4,20 +4,14 @@ import ( "context" "fmt" "os" - "os/signal" "path/filepath" - "strconv" "strings" - "syscall" "time" grpc "github.com/go-skynet/LocalAI/pkg/grpc" "github.com/hashicorp/go-multierror" - "github.com/hpcloud/tail" "github.com/phayes/freeport" "github.com/rs/zerolog/log" - - process "github.com/mudler/go-processmanager" ) const ( @@ -65,89 +59,6 @@ var AutoLoadBackends []string = []string{ PiperBackend, } -func (ml *ModelLoader) GetGRPCPID(id string) (int, error) { - p, exists := ml.grpcProcesses[id] - if !exists { - return -1, fmt.Errorf("no grpc backend found for %s", id) - } - return strconv.Atoi(p.PID) -} - -type GRPCProcessFilter = func(p *process.Process) bool - -func includeAllProcesses(_ *process.Process) bool { - return true -} - -func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) { - for _, p := range ml.grpcProcesses { - if filter(p) { - p.Stop() - } - } -} - -func (ml *ModelLoader) StopAllGRPC() { - ml.StopGRPC(includeAllProcesses) - // for _, p := range ml.grpcProcesses { - // p.Stop() - // } -} - -func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error { - // Make sure the process is executable - if err := os.Chmod(grpcProcess, 0755); err != nil { - return err - } - - log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess) - - log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress) - - grpcControlProcess := process.New( - process.WithTemporaryStateDir(), - process.WithName(grpcProcess), - process.WithArgs("--addr", serverAddress), - process.WithEnvironment(os.Environ()...), - ) - - ml.grpcProcesses[id] = grpcControlProcess - - if err := grpcControlProcess.Run(); err != nil { - return err - } - - log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir()) - // clean up process - go func() { - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c - grpcControlProcess.Stop() - }() - - go func() { - t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) - if err != nil { - log.Debug().Msgf("Could not tail stderr") - } - for line := range t.Lines { - log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text) - } - }() - go func() { - t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true}) - if err != nil { - log.Debug().Msgf("Could not tail stdout") - } - for line := range t.Lines { - log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text) - } - }() - - return nil -} - // starts the grpcModelProcess for the backend, and returns a grpc client // It also loads the model func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*grpc.Client, error) { @@ -248,6 +159,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er backend := strings.ToLower(o.backendString) + if o.singleActiveBackend { + ml.mu.Lock() + log.Debug().Msgf("Stopping all backends except '%s'", o.model) + ml.StopAllExcept(o.model) + ml.mu.Unlock() + } + // if an external backend is provided, use it _, externalBackendExists := o.externalBackends[backend] if externalBackendExists { @@ -274,14 +192,21 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { o := NewOptions(opts...) - // Is this really needed? BackendLoader already does this ml.mu.Lock() + // Return earlier if we have a model already loaded + // (avoid looping through all the backends) if m := ml.CheckIsLoaded(o.model); m != nil { log.Debug().Msgf("Model '%s' already loaded", o.model) ml.mu.Unlock() return m, nil } + // If we can have only one backend active, kill all the others (except external backends) + if o.singleActiveBackend { + log.Debug().Msgf("Stopping all backends except '%s'", o.model) + ml.StopAllExcept(o.model) + } ml.mu.Unlock() + var err error // autoload also external backends diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 4191cea1..8d129e46 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -137,9 +137,7 @@ func (ml *ModelLoader) CheckIsLoaded(s string) *grpc.Client { if !ml.grpcProcesses[s].IsAlive() { log.Debug().Msgf("GRPC Process is not responding: %s", s) // stop and delete the process, this forces to re-load the model and re-create again the service - ml.grpcProcesses[s].Stop() - delete(ml.grpcProcesses, s) - delete(ml.models, s) + ml.deleteProcess(s) return nil } } diff --git a/pkg/model/options.go b/pkg/model/options.go index 550c50c7..faaf6fb2 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -17,8 +17,9 @@ type Options struct { externalBackends map[string]string - grpcAttempts int - grpcAttemptsDelay int + grpcAttempts int + grpcAttemptsDelay int + singleActiveBackend bool } type Option func(*Options) @@ -80,6 +81,12 @@ func WithContext(ctx context.Context) Option { } } +func WithSingleActiveBackend() Option { + return func(o *Options) { + o.singleActiveBackend = true + } +} + func NewOptions(opts ...Option) *Options { o := &Options{ gRPCOptions: &pb.ModelOptions{}, diff --git a/pkg/model/process.go b/pkg/model/process.go new file mode 100644 index 00000000..156f4195 --- /dev/null +++ b/pkg/model/process.go @@ -0,0 +1,118 @@ +package model + +import ( + "fmt" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + "time" + + "github.com/hpcloud/tail" + process "github.com/mudler/go-processmanager" + "github.com/rs/zerolog/log" +) + +func (ml *ModelLoader) StopAllExcept(s string) { + ml.StopGRPC(func(id string, p *process.Process) bool { + if id != s { + for ml.models[id].IsBusy() { + log.Debug().Msgf("%s busy. Waiting.", id) + time.Sleep(2 * time.Second) + } + log.Debug().Msgf("[single-backend] Stopping %s", id) + return true + } + return false + }) +} + +func (ml *ModelLoader) deleteProcess(s string) error { + if err := ml.grpcProcesses[s].Stop(); err != nil { + return err + } + delete(ml.grpcProcesses, s) + delete(ml.models, s) + return nil +} + +type GRPCProcessFilter = func(id string, p *process.Process) bool + +func includeAllProcesses(_ string, _ *process.Process) bool { + return true +} + +func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) { + for k, p := range ml.grpcProcesses { + if filter(k, p) { + ml.deleteProcess(k) + } + } +} + +func (ml *ModelLoader) StopAllGRPC() { + ml.StopGRPC(includeAllProcesses) +} + +func (ml *ModelLoader) GetGRPCPID(id string) (int, error) { + p, exists := ml.grpcProcesses[id] + if !exists { + return -1, fmt.Errorf("no grpc backend found for %s", id) + } + return strconv.Atoi(p.PID) +} + +func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error { + // Make sure the process is executable + if err := os.Chmod(grpcProcess, 0755); err != nil { + return err + } + + log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess) + + log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress) + + grpcControlProcess := process.New( + process.WithTemporaryStateDir(), + process.WithName(grpcProcess), + process.WithArgs("--addr", serverAddress), + process.WithEnvironment(os.Environ()...), + ) + + ml.grpcProcesses[id] = grpcControlProcess + + if err := grpcControlProcess.Run(); err != nil { + return err + } + + log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir()) + // clean up process + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + grpcControlProcess.Stop() + }() + + go func() { + t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) + if err != nil { + log.Debug().Msgf("Could not tail stderr") + } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text) + } + }() + go func() { + t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true}) + if err != nil { + log.Debug().Msgf("Could not tail stdout") + } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text) + } + }() + + return nil +}