feat: add --single-active-backend to allow only one backend active at the time (#925)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-08-19 01:49:33 +02:00 committed by GitHub
parent 1079b18ff7
commit afdc0ebfd7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 238 additions and 164 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"sync"
"time"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
@ -14,6 +15,8 @@ import (
type Client struct {
address string
busy bool
sync.Mutex
}
func NewClient(address string) *Client {
@ -22,7 +25,21 @@ func NewClient(address string) *Client {
}
}
func (c *Client) IsBusy() bool {
c.Lock()
defer c.Unlock()
return c.busy
}
func (c *Client) setBusy(v bool) {
c.Lock()
c.busy = v
c.Unlock()
}
func (c *Client) HealthCheck(ctx context.Context) bool {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
fmt.Println(err)
@ -49,6 +66,8 @@ func (c *Client) HealthCheck(ctx context.Context) bool {
}
func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -60,6 +79,8 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
}
func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -71,6 +92,8 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
}
func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -81,6 +104,8 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
}
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return err
@ -110,6 +135,8 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
}
func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -120,6 +147,8 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
}
func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -130,6 +159,8 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
}
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -160,6 +191,8 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}
func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
@ -176,6 +209,8 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
}
func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err

View file

@ -4,20 +4,14 @@ import (
"context"
"fmt"
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/hashicorp/go-multierror"
"github.com/hpcloud/tail"
"github.com/phayes/freeport"
"github.com/rs/zerolog/log"
process "github.com/mudler/go-processmanager"
)
const (
@ -65,89 +59,6 @@ var AutoLoadBackends []string = []string{
PiperBackend,
}
func (ml *ModelLoader) GetGRPCPID(id string) (int, error) {
p, exists := ml.grpcProcesses[id]
if !exists {
return -1, fmt.Errorf("no grpc backend found for %s", id)
}
return strconv.Atoi(p.PID)
}
type GRPCProcessFilter = func(p *process.Process) bool
func includeAllProcesses(_ *process.Process) bool {
return true
}
func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) {
for _, p := range ml.grpcProcesses {
if filter(p) {
p.Stop()
}
}
}
func (ml *ModelLoader) StopAllGRPC() {
ml.StopGRPC(includeAllProcesses)
// for _, p := range ml.grpcProcesses {
// p.Stop()
// }
}
func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
// Make sure the process is executable
if err := os.Chmod(grpcProcess, 0755); err != nil {
return err
}
log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)
log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
grpcControlProcess := process.New(
process.WithTemporaryStateDir(),
process.WithName(grpcProcess),
process.WithArgs("--addr", serverAddress),
process.WithEnvironment(os.Environ()...),
)
ml.grpcProcesses[id] = grpcControlProcess
if err := grpcControlProcess.Run(); err != nil {
return err
}
log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
// clean up process
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
grpcControlProcess.Stop()
}()
go func() {
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
if err != nil {
log.Debug().Msgf("Could not tail stderr")
}
for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
}
}()
go func() {
t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
if err != nil {
log.Debug().Msgf("Could not tail stdout")
}
for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
}
}()
return nil
}
// starts the grpcModelProcess for the backend, and returns a grpc client
// It also loads the model
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*grpc.Client, error) {
@ -248,6 +159,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
backend := strings.ToLower(o.backendString)
if o.singleActiveBackend {
ml.mu.Lock()
log.Debug().Msgf("Stopping all backends except '%s'", o.model)
ml.StopAllExcept(o.model)
ml.mu.Unlock()
}
// if an external backend is provided, use it
_, externalBackendExists := o.externalBackends[backend]
if externalBackendExists {
@ -274,14 +192,21 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
o := NewOptions(opts...)
// Is this really needed? BackendLoader already does this
ml.mu.Lock()
// Return earlier if we have a model already loaded
// (avoid looping through all the backends)
if m := ml.CheckIsLoaded(o.model); m != nil {
log.Debug().Msgf("Model '%s' already loaded", o.model)
ml.mu.Unlock()
return m, nil
}
// If we can have only one backend active, kill all the others (except external backends)
if o.singleActiveBackend {
log.Debug().Msgf("Stopping all backends except '%s'", o.model)
ml.StopAllExcept(o.model)
}
ml.mu.Unlock()
var err error
// autoload also external backends

View file

@ -137,9 +137,7 @@ func (ml *ModelLoader) CheckIsLoaded(s string) *grpc.Client {
if !ml.grpcProcesses[s].IsAlive() {
log.Debug().Msgf("GRPC Process is not responding: %s", s)
// stop and delete the process, this forces to re-load the model and re-create again the service
ml.grpcProcesses[s].Stop()
delete(ml.grpcProcesses, s)
delete(ml.models, s)
ml.deleteProcess(s)
return nil
}
}

View file

@ -17,8 +17,9 @@ type Options struct {
externalBackends map[string]string
grpcAttempts int
grpcAttemptsDelay int
grpcAttempts int
grpcAttemptsDelay int
singleActiveBackend bool
}
type Option func(*Options)
@ -80,6 +81,12 @@ func WithContext(ctx context.Context) Option {
}
}
func WithSingleActiveBackend() Option {
return func(o *Options) {
o.singleActiveBackend = true
}
}
func NewOptions(opts ...Option) *Options {
o := &Options{
gRPCOptions: &pb.ModelOptions{},

118
pkg/model/process.go Normal file
View file

@ -0,0 +1,118 @@
package model
import (
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/hpcloud/tail"
process "github.com/mudler/go-processmanager"
"github.com/rs/zerolog/log"
)
func (ml *ModelLoader) StopAllExcept(s string) {
ml.StopGRPC(func(id string, p *process.Process) bool {
if id != s {
for ml.models[id].IsBusy() {
log.Debug().Msgf("%s busy. Waiting.", id)
time.Sleep(2 * time.Second)
}
log.Debug().Msgf("[single-backend] Stopping %s", id)
return true
}
return false
})
}
func (ml *ModelLoader) deleteProcess(s string) error {
if err := ml.grpcProcesses[s].Stop(); err != nil {
return err
}
delete(ml.grpcProcesses, s)
delete(ml.models, s)
return nil
}
type GRPCProcessFilter = func(id string, p *process.Process) bool
func includeAllProcesses(_ string, _ *process.Process) bool {
return true
}
func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) {
for k, p := range ml.grpcProcesses {
if filter(k, p) {
ml.deleteProcess(k)
}
}
}
func (ml *ModelLoader) StopAllGRPC() {
ml.StopGRPC(includeAllProcesses)
}
func (ml *ModelLoader) GetGRPCPID(id string) (int, error) {
p, exists := ml.grpcProcesses[id]
if !exists {
return -1, fmt.Errorf("no grpc backend found for %s", id)
}
return strconv.Atoi(p.PID)
}
func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
// Make sure the process is executable
if err := os.Chmod(grpcProcess, 0755); err != nil {
return err
}
log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)
log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
grpcControlProcess := process.New(
process.WithTemporaryStateDir(),
process.WithName(grpcProcess),
process.WithArgs("--addr", serverAddress),
process.WithEnvironment(os.Environ()...),
)
ml.grpcProcesses[id] = grpcControlProcess
if err := grpcControlProcess.Run(); err != nil {
return err
}
log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
// clean up process
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
grpcControlProcess.Stop()
}()
go func() {
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
if err != nil {
log.Debug().Msgf("Could not tail stderr")
}
for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
}
}()
go func() {
t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
if err != nil {
log.Debug().Msgf("Could not tail stdout")
}
for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
}
}()
return nil
}