fix: untangle pkg/grpc and core/schema for Transcription (#3419)

untangle pkg/grpc and core/schema in Transcribe

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-09-02 09:48:53 -04:00 committed by GitHub
parent 1655411ccd
commit c2804c42fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 162 additions and 186 deletions

View file

@ -3,7 +3,6 @@ package grpc
import (
"context"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc"
)
@ -42,7 +41,7 @@ type Backend interface {
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
Status(ctx context.Context) (*pb.StatusResponse, error)

View file

@ -6,7 +6,6 @@ import (
"fmt"
"os"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
@ -53,8 +52,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) {
return schema.TranscriptionResult{}, fmt.Errorf("unimplemented")
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
}
func (llm *Base) TTS(*pb.TTSRequest) error {

View file

@ -7,7 +7,6 @@ import (
"sync"
"time"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
@ -228,7 +227,7 @@ func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ
return client.SoundGeneration(ctx, in, opts...)
}
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
@ -243,27 +242,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}
defer conn.Close()
client := pb.NewBackendClient(conn)
res, err := client.AudioTranscription(ctx, in, opts...)
if err != nil {
return nil, err
}
tresult := &schema.TranscriptionResult{}
for _, s := range res.Segments {
tks := []int{}
for _, t := range s.Tokens {
tks = append(tks, int(t))
}
tresult.Segments = append(tresult.Segments,
schema.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),
End: time.Duration(s.End),
Tokens: tks,
})
}
tresult.Text = res.Text
return tresult, err
return client.AudioTranscription(ctx, in, opts...)
}
func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {

View file

@ -2,9 +2,7 @@ package grpc
import (
"context"
"time"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
@ -57,28 +55,8 @@ func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerati
return e.s.SoundGeneration(ctx, in)
}
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
r, err := e.s.AudioTranscription(ctx, in)
if err != nil {
return nil, err
}
tr := &schema.TranscriptionResult{}
for _, s := range r.Segments {
var tks []int
for _, t := range s.Tokens {
tks = append(tks, int(t))
}
tr.Segments = append(tr.Segments,
schema.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),
End: time.Duration(s.End),
Tokens: tks,
})
}
tr.Text = r.Text
return tr, err
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) {
return e.s.AudioTranscription(ctx, in)
}
func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {

View file

@ -1,7 +1,6 @@
package grpc
import (
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
@ -15,7 +14,7 @@ type LLM interface {
Load(*pb.ModelOptions) error
Embeddings(*pb.PredictOptions) ([]float32, error)
GenerateImage(*pb.GenerateImageRequest) error
AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error)
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
TTS(*pb.TTSRequest) error
SoundGeneration(*pb.SoundGenerationRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)