fix: drop racy code, refactor and group API schema (#931)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-08-20 14:04:45 +02:00 committed by GitHub
parent 28db83e17b
commit cc060a283d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 239 additions and 317 deletions

View file

@ -30,6 +30,10 @@ func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, e
}
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
embeds, err := s.llm.Embeddings(in)
if err != nil {
return nil, err
@ -39,6 +43,10 @@ func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.Embe
}
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.Load(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err
@ -47,11 +55,19 @@ func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result
}
func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
result, err := s.llm.Predict(in)
return newReply(result), err
}
func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.GenerateImage(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err
@ -60,6 +76,10 @@ func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest)
}
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.TTS(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
@ -68,6 +88,10 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error)
}
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
result, err := s.llm.AudioTranscription(in)
if err != nil {
return nil, err
@ -93,7 +117,10 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
resultChan := make(chan string)
done := make(chan bool)
@ -111,6 +138,10 @@ func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictS
}
func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.TokenizeString(in)
if err != nil {
return nil, err