fix(model-loading): keep track of open GRPC Clients (#3377)

Due to a previous refactor we moved the client constructor tight to the
model address, however that was just a string which we would use to
build the client each time.

With this change we make the loader to return a *Model which carries a
constructor for the client and stores the client on the first
connection.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-08-25 14:36:09 +02:00 committed by GitHub
parent 771a052480
commit 7f06954425
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 176 additions and 171 deletions

View file

@ -39,6 +39,18 @@ func (c *Client) setBusy(v bool) {
c.Unlock()
}
func (c *Client) wdMark() {
if c.wd != nil {
c.wd.Mark(c.address)
}
}
func (c *Client) wdUnMark() {
if c.wd != nil {
c.wd.UnMark(c.address)
}
}
func (c *Client) HealthCheck(ctx context.Context) (bool, error) {
if !c.parallel {
c.opMutex.Lock()
@ -76,10 +88,8 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -97,10 +107,8 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -118,10 +126,8 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -138,10 +144,8 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return err
@ -177,10 +181,8 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -197,10 +199,8 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -217,10 +217,8 @@ func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -237,10 +235,8 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -277,10 +273,8 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
}
c.setBusy(true)
defer c.setBusy(false)
if c.wd != nil {
c.wd.Mark(c.address)
defer c.wd.UnMark(c.address)
}
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -319,6 +313,8 @@ func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ..
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -333,6 +329,8 @@ func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, o
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.wdMark()
defer c.wdUnMark()
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
@ -351,6 +349,8 @@ func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ..
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -367,6 +367,8 @@ func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -383,6 +385,8 @@ func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err