From ebfe8dd1196d0fa6227b7b8844048112625f8c31 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Mon, 18 Nov 2024 19:12:27 +0100
Subject: [PATCH] gRPC client stubs

Signed-off-by: Ettore Di Giacinto
---
 backend/backend.proto                        |  2 +-
 core/http/endpoints/openai/realtime.go       | 17 ++++++++++---
 core/http/endpoints/openai/realtime_model.go | 26 ++++++++++++++++++++
 pkg/grpc/backend.go                          |  2 +-
 4 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/backend/backend.proto b/backend/backend.proto
index 3137be09..162fb595 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -159,7 +159,7 @@ message Reply {
   bytes message = 1;
   int32 tokens = 2;
   int32 prompt_tokens = 3;
-  string audio_output = 4;
+  bytes audio = 5;
 }
 
 message ModelOptions {
diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go
index c841a3e4..43f268cf 100644
--- a/core/http/endpoints/openai/realtime.go
+++ b/core/http/endpoints/openai/realtime.go
@@ -120,6 +120,8 @@ var sessionLock sync.Mutex
 // TODO: implement interface as we start to define usages
 type Model interface {
 	VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
+	Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
+	PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
 }
 
 func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
@@ -800,7 +802,17 @@ func processAudioResponse(session *Session, audioData []byte) (string, []byte, *FunctionCall, error) {
 	// 4. Convert the response text to speech (audio)
 	//
 	// Placeholder implementation:
-	// TODO: use session.ModelInterface...
+
+	// TODO: template eventual messages, like chat.go
+	reply, err := session.ModelInterface.Predict(context.Background(), &proto.PredictOptions{
+		Prompt: "What's the weather in New York?",
+	})
+
+	if err != nil {
+		return "", nil, nil, err
+	}
+
+	generatedAudio := reply.Audio
 
 	transcribedText := "What's the weather in New York?"
 	var functionCall *FunctionCall
@@ -819,9 +831,6 @@ func processAudioResponse(session *Session, audioData []byte) (string, []byte, *FunctionCall, error) {
 
 	// Generate a response
 	generatedText := "This is a response to your speech input."
-	generatedAudio := []byte{} // Generate audio bytes from the generatedText
-
-	// TODO: Implement actual transcription and TTS
 
 	return generatedText, generatedAudio, nil, nil
 }
diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go
index a32f8c10..20b77862 100644
--- a/core/http/endpoints/openai/realtime_model.go
+++ b/core/http/endpoints/openai/realtime_model.go
@@ -13,6 +13,11 @@ import (
 	"google.golang.org/grpc"
 )
 
+var (
+	_ Model = new(wrappedModel)
+	_ Model = new(anyToAnyModel)
+)
+
 // wrappedModel represent a model which does not support Any-to-Any operations
 // This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
 // which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS)
@@ -47,6 +52,27 @@ func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
 	return m.VADClient.VAD(ctx, in)
 }
 
+func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
+	// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
+	// sound.BufferAsWAV(audioData, "audio.wav")
+
+	return m.LLMClient.Predict(ctx, in)
+}
+
+func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
+	// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
+
+	return m.LLMClient.PredictStream(ctx, in, f)
+}
+
+func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
+	return m.LLMClient.Predict(ctx, in)
+}
+
+func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
+	return m.LLMClient.PredictStream(ctx, in, f)
+}
+
 // returns and loads either a wrapped model or a model that support audio-to-audio
 func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {
diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go
index fabc0268..9b82a62e 100644
--- a/pkg/grpc/backend.go
+++ b/pkg/grpc/backend.go
@@ -35,9 +35,9 @@ type Backend interface {
 	IsBusy() bool
 	HealthCheck(ctx context.Context) (bool, error)
 	Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
-	Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
 	LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
 	PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
+	Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
 	GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
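
A note on the pipeline TODOs in wrappedModel above: they describe an STT -> LLM -> TTS chain (audio to text, text to completion, completion back to audio). Below is a minimal sketch of that chain under stated assumptions: the Transcriber, LLM, and Synthesizer interfaces and the Pipeline function are invented for illustration and are not LocalAI API.

package pipeline

import (
	"context"
	"fmt"
)

// Transcriber, LLM, and Synthesizer stand in for the STT, text, and TTS
// clients the wrapped pipeline would call; their shapes are assumptions.
type Transcriber interface {
	Transcribe(ctx context.Context, audio []byte) (string, error)
}

type LLM interface {
	Predict(ctx context.Context, prompt string) (string, error)
}

type Synthesizer interface {
	Synthesize(ctx context.Context, text string) ([]byte, error)
}

// Pipeline chains audio -> text -> completion -> audio, the conversion the
// wrappedModel TODOs point at, returning both the text and the audio reply.
func Pipeline(ctx context.Context, stt Transcriber, llm LLM, tts Synthesizer, audioIn []byte) (string, []byte, error) {
	prompt, err := stt.Transcribe(ctx, audioIn)
	if err != nil {
		return "", nil, fmt.Errorf("transcription: %w", err)
	}
	text, err := llm.Predict(ctx, prompt)
	if err != nil {
		return "", nil, fmt.Errorf("prediction: %w", err)
	}
	audioOut, err := tts.Synthesize(ctx, text)
	if err != nil {
		return "", nil, fmt.Errorf("synthesis: %w", err)
	}
	return text, audioOut, nil
}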
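
Separately, the patch leaves two stream-callback shapes in play: Backend.PredictStream in pkg/grpc takes f func(reply *pb.Reply), while the realtime Model interface takes f func(s []byte). If LLMClient follows the Backend interface, the callback cannot be passed through as-is; a closure over the reply's message bytes (bytes message = 1 in backend.proto) would bridge the gap. A hypothetical adapter, assuming the pb alias used in pkg/grpc/backend.go:

// bytesCallback adapts a byte-oriented callback to the *pb.Reply-oriented
// signature Backend.PredictStream expects, forwarding the reply's message
// bytes. Hypothetical helper, not part of this patch.
func bytesCallback(f func(s []byte)) func(reply *pb.Reply) {
	return func(reply *pb.Reply) {
		f(reply.GetMessage())
	}
}

With it, wrappedModel.PredictStream could call m.LLMClient.PredictStream(ctx, in, bytesCallback(f)) without changing the Model interface.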