mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
fix: drop racy code, refactor and group API schema (#931)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
28db83e17b
commit
cc060a283d
55 changed files with 239 additions and 317 deletions
|
@ -30,6 +30,10 @@ func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, e
|
|||
}
|
||||
|
||||
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
embeds, err := s.llm.Embeddings(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -39,6 +43,10 @@ func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.Embe
|
|||
}
|
||||
|
||||
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
err := s.llm.Load(in)
|
||||
if err != nil {
|
||||
return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err
|
||||
|
@ -47,11 +55,19 @@ func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result
|
|||
}
|
||||
|
||||
func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
result, err := s.llm.Predict(in)
|
||||
return newReply(result), err
|
||||
}
|
||||
|
||||
func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
err := s.llm.GenerateImage(in)
|
||||
if err != nil {
|
||||
return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err
|
||||
|
@ -60,6 +76,10 @@ func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest)
|
|||
}
|
||||
|
||||
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
err := s.llm.TTS(in)
|
||||
if err != nil {
|
||||
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
|
||||
|
@ -68,6 +88,10 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error)
|
|||
}
|
||||
|
||||
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
result, err := s.llm.AudioTranscription(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -93,7 +117,10 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
|||
}
|
||||
|
||||
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error {
|
||||
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
resultChan := make(chan string)
|
||||
|
||||
done := make(chan bool)
|
||||
|
@ -111,6 +138,10 @@ func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictS
|
|||
}
|
||||
|
||||
func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
res, err := s.llm.TokenizeString(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue