diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go index 5f102c43..554cb11e 100644 --- a/api/backend/embeddings.go +++ b/api/backend/embeddings.go @@ -30,6 +30,14 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. model.WithContext(o.Context), } + if c.GRPC.Attempts != 0 { + opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) + } + + if c.GRPC.AttemptsSleepTime != 0 { + opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) + } + for k, v := range o.ExternalGRPCBackends { opts = append(opts, model.WithExternalBackend(k, v)) } diff --git a/api/backend/image.go b/api/backend/image.go index a8e04167..6126aff6 100644 --- a/api/backend/image.go +++ b/api/backend/image.go @@ -24,6 +24,14 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat }), } + if c.GRPC.Attempts != 0 { + opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) + } + + if c.GRPC.AttemptsSleepTime != 0 { + opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) + } + for k, v := range o.ExternalGRPCBackends { opts = append(opts, model.WithExternalBackend(k, v)) } diff --git a/api/backend/llm.go b/api/backend/llm.go index a98bfb21..80067e7c 100644 --- a/api/backend/llm.go +++ b/api/backend/llm.go @@ -31,6 +31,14 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c model.WithContext(o.Context), } + if c.GRPC.Attempts != 0 { + opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) + } + + if c.GRPC.AttemptsSleepTime != 0 { + opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) + } + for k, v := range o.ExternalGRPCBackends { opts = append(opts, model.WithExternalBackend(k, v)) } diff --git a/api/backend/transcript.go b/api/backend/transcript.go index add16f4c..80a759de 100644 --- a/api/backend/transcript.go +++ b/api/backend/transcript.go @@ -21,6 +21,13 @@ func ModelTranscription(audio, language string, loader *model.ModelLoader, c con model.WithAssetDir(o.AssetsDestination), } + if c.GRPC.Attempts != 0 { + opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) + } + + if c.GRPC.AttemptsSleepTime != 0 { + opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) + } for k, v := range o.ExternalGRPCBackends { opts = append(opts, model.WithExternalBackend(k, v)) } diff --git a/api/config/config.go b/api/config/config.go index d877a87c..52bb427b 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -39,6 +39,14 @@ type Config struct { Diffusers Diffusers `yaml:"diffusers"` Step int `yaml:"step"` + + // GRPC Options + GRPC GRPC `yaml:"grpc"` +} + +type GRPC struct { + Attempts int `yaml:"attempts"` + AttemptsSleepTime int `yaml:"attempts_sleep_time"` } type Diffusers struct { diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 7b808c25..49c472f7 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -185,13 +185,13 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string // Wait for the service to start up ready := false - for i := 0; i < 10; i++ { + for i := 0; i < o.grpcAttempts; i++ { if client.HealthCheck(context.Background()) { log.Debug().Msgf("GRPC Service Ready") ready = true break } - time.Sleep(1 * time.Second) + time.Sleep(time.Duration(o.grpcAttemptsDelay) * time.Second) } if !ready { diff --git a/pkg/model/options.go b/pkg/model/options.go index 83c14b65..550c50c7 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -16,6 +16,9 @@ type Options struct { gRPCOptions *pb.ModelOptions externalBackends map[string]string + + grpcAttempts int + grpcAttemptsDelay int } type Option func(*Options) @@ -29,6 +32,18 @@ func WithExternalBackend(name string, uri string) Option { } } +func WithGRPCAttempts(attempts int) Option { + return func(o *Options) { + o.grpcAttempts = attempts + } +} + +func WithGRPCAttemptsDelay(delay int) Option { + return func(o *Options) { + o.grpcAttemptsDelay = delay + } +} + func WithBackendString(backend string) Option { return func(o *Options) { o.backendString = backend @@ -67,8 +82,10 @@ func WithContext(ctx context.Context) Option { func NewOptions(opts ...Option) *Options { o := &Options{ - gRPCOptions: &pb.ModelOptions{}, - context: context.Background(), + gRPCOptions: &pb.ModelOptions{}, + context: context.Background(), + grpcAttempts: 20, + grpcAttemptsDelay: 2, } for _, opt := range opts { opt(o)