diff --git a/Dockerfile b/Dockerfile
index c6c426a7..b4cabc59 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -285,20 +285,40 @@ EOT
 ###################################
 ###################################
 
-# The builder target compiles LocalAI. This target is not the target that will be uploaded to the registry.
-# Adjustments to the build process should likely be made here.
-FROM builder-base AS builder
+# Compile backends first in a separate stage
+FROM builder-base AS builder-backends
 
-# Install the pre-built GRPC
 COPY --from=grpc /opt/grpc /usr/local
 
-# Rebuild with defaults backends
+WORKDIR /build
+
+COPY ./Makefile .
+COPY ./backend ./backend
+COPY ./go.mod .
+COPY ./go.sum .
+COPY ./.git ./.git
+
+# Some of the Go backends use libs from the main source tree; caching could be further optimized by building the C++ backends before this point
+COPY ./pkg/grpc ./pkg/grpc
+COPY ./pkg/utils ./pkg/utils
+COPY ./pkg/langchain ./pkg/langchain
+
+RUN ls -l ./
+RUN make backend-assets
+RUN make prepare
+RUN if [ "${TARGETARCH}" = "arm64" ] || [ "${BUILD_TYPE}" = "hipblas" ]; then \
+        SKIP_GRPC_BACKEND="backend-assets/grpc/llama-cpp-avx512 backend-assets/grpc/llama-cpp-avx backend-assets/grpc/llama-cpp-avx2" make grpcs; \
+    else \
+        make grpcs; \
+    fi
+
+# The builder target compiles LocalAI. This target is not the target that will be uploaded to the registry.
+# Adjustments to the build process should likely be made here.
+FROM builder-backends AS builder
+
 WORKDIR /build
 
 COPY . .
-COPY .git .
-
-RUN make prepare
 
 ## Build the binary
 ## If we're on arm64 AND using cublas/hipblas, skip some of the llama-compat backends to save space
@@ -390,8 +410,6 @@ COPY . .
 COPY --from=builder /build/sources ./sources/
 COPY --from=grpc /opt/grpc /usr/local
 
-RUN make prepare-sources
-
 # Copy the binary
 COPY --from=builder /build/local-ai ./
 
diff --git a/Makefile b/Makefile
index 2e0c886d..460d000b 100644
--- a/Makefile
+++ b/Makefile
@@ -842,10 +842,9 @@ docker-aio-all:
 
 docker-image-intel:
 	docker build \
-		--progress plain \
 		--build-arg BASE_IMAGE=intel/oneapi-basekit:2025.1.0-0-devel-ubuntu24.04 \
 		--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
-		--build-arg GO_TAGS="none" \
+		--build-arg GO_TAGS="$(GO_TAGS)" \
 		--build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
 		--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
 
@@ -853,7 +852,7 @@ docker-image-intel-xpu:
 	docker build \
 		--build-arg BASE_IMAGE=intel/oneapi-basekit:2025.1.0-0-devel-ubuntu22.04 \
 		--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
-		--build-arg GO_TAGS="none" \
+		--build-arg GO_TAGS="$(GO_TAGS)" \
 		--build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
 		--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
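With GO_TAGS now forwarded from the environment instead of being pinned to "none", the Intel images can be built with the same tag set as the other images. A possible invocation (the tag values and image name here are illustrative, not mandated by this patch):

    make docker-image-intel GO_TAGS="tts stablediffusion" DOCKER_IMAGE=localai:sycl-f32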
diff --git a/backend/backend.proto b/backend/backend.proto
index 9021a353..47ee7c6e 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -162,6 +162,7 @@ message Reply {
   int32 prompt_tokens = 3;
   double timing_prompt_processing = 4;
   double timing_token_generation = 5;
+  bytes audio = 6;
 }
 
 message GrammarTrigger {
diff --git a/backend/go/vad/silero/vad.go b/backend/go/vad/silero/vad.go
index 5a164d2a..f3e9f7be 100644
--- a/backend/go/vad/silero/vad.go
+++ b/backend/go/vad/silero/vad.go
@@ -21,8 +21,8 @@ func (vad *VAD) Load(opts *pb.ModelOptions) error {
 		SampleRate: 16000,
 		//WindowSize: 1024,
 		Threshold: 0.5,
-		MinSilenceDurationMs: 0,
-		SpeechPadMs:          0,
+		MinSilenceDurationMs: 100,
+		SpeechPadMs:          30,
 	})
 	if err != nil {
 		return fmt.Errorf("create silero detector: %w", err)
@@ -35,6 +35,10 @@ func (vad *VAD) Load(opts *pb.ModelOptions) error {
 func (vad *VAD) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
 	audio := req.Audio
 
+	if err := vad.detector.Reset(); err != nil {
+		return pb.VADResponse{}, fmt.Errorf("reset: %w", err)
+	}
+
 	segments, err := vad.detector.Detect(audio)
 	if err != nil {
 		return pb.VADResponse{}, fmt.Errorf("detect: %w", err)
diff --git a/core/backend/llm.go b/core/backend/llm.go
index 57e2ae35..f36a568a 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -22,8 +22,9 @@ import (
 )
 
 type LLMResponse struct {
-	Response string // should this be []byte?
-	Usage    TokenUsage
+	Response    string // should this be []byte?
+	Usage       TokenUsage
+	AudioOutput string
 }
 
 type TokenUsage struct {
diff --git a/core/config/backend_config.go b/core/config/backend_config.go
index ec0f2812..0be552e6 100644
--- a/core/config/backend_config.go
+++ b/core/config/backend_config.go
@@ -37,6 +37,7 @@ type BackendConfig struct {
 	TemplateConfig      TemplateConfig         `yaml:"template"`
 	KnownUsecaseStrings []string               `yaml:"known_usecases"`
 	KnownUsecases       *BackendConfigUsecases `yaml:"-"`
+	Pipeline            Pipeline               `yaml:"pipeline"`
 
 	PromptStrings, InputStrings []string `yaml:"-"`
 	InputToken                  [][]int  `yaml:"-"`
@@ -72,6 +73,14 @@ type BackendConfig struct {
 	Options []string `yaml:"options"`
 }
 
+// Pipeline defines other models to use for audio-to-audio
+type Pipeline struct {
+	TTS           string `yaml:"tts"`
+	LLM           string `yaml:"llm"`
+	Transcription string `yaml:"transcription"`
+	VAD           string `yaml:"vad"`
+}
+
 type File struct {
 	Filename string `yaml:"filename" json:"filename"`
 	SHA256   string `yaml:"sha256" json:"sha256"`
diff --git a/core/http/app.go b/core/http/app.go
index 0edd7ef1..bce27397 100644
--- a/core/http/app.go
+++ b/core/http/app.go
@@ -9,6 +9,7 @@ import (
 	"path/filepath"
 
 	"github.com/dave-gray101/v2keyauth"
+	"github.com/gofiber/websocket/v2"
 	"github.com/mudler/LocalAI/pkg/utils"
 
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
@@ -99,6 +100,15 @@ func API(application *application.Application) (*fiber.App, error) {
 		})
 	}
 
+	router.Use("/v1/realtime", func(c *fiber.Ctx) error {
+		// IsWebSocketUpgrade returns true if the client requested an upgrade to the WebSocket protocol
+		if websocket.IsWebSocketUpgrade(c) {
+			return c.Next()
+		}
+
+		return nil
+	})
+
 	router.Hooks().OnListen(func(listenData fiber.ListenData) error {
 		scheme := "http"
 		if listenData.TLS {
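The middleware above lets WebSocket upgrade requests through and ends anything else with an empty response. A stricter variant would reject plain HTTP requests explicitly; a sketch, not part of this patch:

    router.Use("/v1/realtime", func(c *fiber.Ctx) error {
        if websocket.IsWebSocketUpgrade(c) {
            return c.Next()
        }
        // 426 Upgrade Required for non-WebSocket requests
        return fiber.ErrUpgradeRequired
    })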
diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go
new file mode 100644
index 00000000..83d77b0c
--- /dev/null
+++ b/core/http/endpoints/openai/realtime.go
@@ -0,0 +1,1265 @@
+package openai
+
+import (
+    "context"
+    "encoding/base64"
+    "encoding/json"
+    "fmt"
+    "os"
+    "strings"
+    "sync"
+    "time"
+
+    "github.com/go-audio/audio"
+    "github.com/gofiber/fiber/v2"
+    "github.com/gofiber/websocket/v2"
+    "github.com/mudler/LocalAI/core/application"
+    "github.com/mudler/LocalAI/core/config"
+    "github.com/mudler/LocalAI/core/http/endpoints/openai/types"
+    laudio "github.com/mudler/LocalAI/pkg/audio"
+    "github.com/mudler/LocalAI/pkg/functions"
+    "github.com/mudler/LocalAI/pkg/grpc/proto"
+    model "github.com/mudler/LocalAI/pkg/model"
+    "github.com/mudler/LocalAI/pkg/sound"
+    "github.com/mudler/LocalAI/pkg/templates"
+
+    "google.golang.org/grpc"
+
+    "github.com/rs/zerolog/log"
+)
+
+const (
+    localSampleRate  = 16000
+    remoteSampleRate = 24000
+)
+
+// A model can be "emulated", that is: transcribe audio to text -> feed text to the LLM -> generate audio as result.
+// If the model supports audio-to-audio natively, we will use the specific gRPC calls instead.
+
+// Session represents a single WebSocket connection and its state
+type Session struct {
+    ID                      string
+    TranscriptionOnly       bool
+    Model                   string
+    Voice                   string
+    TurnDetection           *types.ServerTurnDetection `json:"turn_detection"` // "server_vad" or "none"
+    InputAudioTranscription *types.InputAudioTranscription
+    Functions               functions.Functions
+    Conversations           map[string]*Conversation
+    InputAudioBuffer        []byte
+    AudioBufferLock         sync.Mutex
+    Instructions            string
+    DefaultConversationID   string
+    ModelInterface          Model
+}
+
+func (s *Session) FromClient(session *types.ClientSession) {
+}
+
+func (s *Session) ToServer() types.ServerSession {
+    return types.ServerSession{
+        ID: s.ID,
+        Object: func() string {
+            if s.TranscriptionOnly {
+                return "realtime.transcription_session"
+            } else {
+                return "realtime.session"
+            }
+        }(),
+        Model:                   s.Model,
+        Modalities:              []types.Modality{types.ModalityText, types.ModalityAudio},
+        Instructions:            s.Instructions,
+        Voice:                   s.Voice,
+        InputAudioFormat:        types.AudioFormatPcm16,
+        OutputAudioFormat:       types.AudioFormatPcm16,
+        TurnDetection:           s.TurnDetection,
+        InputAudioTranscription: s.InputAudioTranscription,
+        // TODO: Should be constructed from Functions?
+        Tools: []types.Tool{},
+        // TODO: ToolChoice
+        // TODO: Temperature
+        // TODO: MaxOutputTokens
+        // TODO: InputAudioNoiseReduction
+    }
+}
+
+// TODO: Update to tools?
+// FunctionCall represents a function call initiated by the model
+type FunctionCall struct {
+    Name      string                 `json:"name"`
+    Arguments map[string]interface{} `json:"arguments"`
+}
+
+// Conversation represents a conversation with a list of items
+type Conversation struct {
+    ID    string
+    Items []*types.MessageItem
+    Lock  sync.Mutex
+}
+
+func (c *Conversation) ToServer() types.Conversation {
+    return types.Conversation{
+        ID:     c.ID,
+        Object: "realtime.conversation",
+    }
+}
+
+// Item represents a message, function_call, or function_call_output
+type Item struct {
+    ID           string                `json:"id"`
+    Object       string                `json:"object"`
+    Type         string                `json:"type"` // "message", "function_call", "function_call_output"
+    Status       string                `json:"status"`
+    Role         string                `json:"role"`
+    Content      []ConversationContent `json:"content,omitempty"`
+    FunctionCall *FunctionCall         `json:"function_call,omitempty"`
+}
+
+// ConversationContent represents the content of an item
+type ConversationContent struct {
+    Type  string `json:"type"` // "input_text", "input_audio", "text", "audio", etc.
+    Audio string `json:"audio,omitempty"`
+    Text  string `json:"text,omitempty"`
+    // Additional fields as needed
+}
+
+// Define the structures for incoming messages
+type IncomingMessage struct {
+    Type     types.ClientEventType `json:"type"`
+    Session  json.RawMessage       `json:"session,omitempty"`
+    Item     json.RawMessage       `json:"item,omitempty"`
+    Audio    string                `json:"audio,omitempty"`
+    Response json.RawMessage       `json:"response,omitempty"`
+    Error    *ErrorMessage         `json:"error,omitempty"`
+    // Other fields as needed
+}
+
+// ErrorMessage represents an error message sent to the client
+type ErrorMessage struct {
+    Type    string `json:"type"`
+    Code    string `json:"code"`
+    Message string `json:"message"`
+    Param   string `json:"param,omitempty"`
+    EventID string `json:"event_id,omitempty"`
+}
+
+// Define a structure for outgoing messages
+type OutgoingMessage struct {
+    Type         string        `json:"type"`
+    Session      *Session      `json:"session,omitempty"`
+    Conversation *Conversation `json:"conversation,omitempty"`
+    Item         *Item         `json:"item,omitempty"`
+    Content      string        `json:"content,omitempty"`
+    Audio        string        `json:"audio,omitempty"`
+    Error        *ErrorMessage `json:"error,omitempty"`
+}
+
+// Map to store sessions (in-memory)
+var sessions = make(map[string]*Session)
+var sessionLock sync.Mutex
+
+// TODO: implement interface as we start to define usages
+type Model interface {
+    VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
+    Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error)
+    Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
+    PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
+}
+
+// TODO: Implement ephemeral keys to allow these endpoints to be used
+func RealtimeSessions(application *application.Application) fiber.Handler {
+    return func(ctx *fiber.Ctx) error {
+        return ctx.SendStatus(501)
+    }
+}
+
+func RealtimeTranscriptionSession(application *application.Application) fiber.Handler {
+    return func(ctx *fiber.Ctx) error {
+        return ctx.SendStatus(501)
+    }
+}
+
+func Realtime(application *application.Application) fiber.Handler {
+    return websocket.New(registerRealtime(application))
+}
+
+func registerRealtime(application *application.Application) func(c *websocket.Conn) {
+    return func(c *websocket.Conn) {
+
+        evaluator := application.TemplatesEvaluator()
+        log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
+
+        model := c.Query("model", "gpt-4o")
+
+        intent := c.Query("intent")
+        if intent != "transcription" {
+            sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter")
+            return
+        }
+
+        log.Debug().Msgf("Realtime params: model=%s, intent=%s", model, intent)
+
+        sessionID := generateSessionID()
+        session := &Session{
+            ID:                sessionID,
+            TranscriptionOnly: true,
+            Model:             model,   // default model
+            Voice:             "alloy", // default voice
+            TurnDetection: &types.ServerTurnDetection{
+                Type: types.ServerTurnDetectionTypeServerVad,
+                TurnDetectionParams: types.TurnDetectionParams{
+                    // TODO: Need some way to pass this to the backend
+                    Threshold: 0.5,
+                    // TODO: This is ignored and the amount of padding is random at present
+                    PrefixPaddingMs:   30,
+                    SilenceDurationMs: 500,
+                    CreateResponse:    func() *bool { t := true; return &t }(),
+                },
+            },
+            InputAudioTranscription: &types.InputAudioTranscription{
+                Model: "whisper-1",
+            },
+            Conversations: make(map[string]*Conversation),
+        }
+
+        // Create a default conversation
+        conversationID := generateConversationID()
+        conversation := &Conversation{
+            ID:    conversationID,
+            Items: []*types.MessageItem{},
+        }
+        session.Conversations[conversationID] = conversation
+        session.DefaultConversationID = conversationID
+
+        // TODO: The API has no way to configure the VAD model or other models that make up a pipeline to fake any-to-any.
+        // So possibly we could have a way to configure a composite model that can be used in situations where any-to-any is expected.
+        pipeline := config.Pipeline{
+            VAD:           "silero-vad",
+            Transcription: session.InputAudioTranscription.Model,
+        }
+
+        m, cfg, err := newTranscriptionOnlyModel(
+            &pipeline,
+            application.BackendLoader(),
+            application.ModelLoader(),
+            application.ApplicationConfig(),
+        )
+        if err != nil {
+            log.Error().Msgf("failed to load model: %s", err.Error())
+            sendError(c, "model_load_error", "Failed to load model", "", "")
+            return
+        }
+        session.ModelInterface = m
+
+        // Store the session
+        sessionLock.Lock()
+        sessions[sessionID] = session
+        sessionLock.Unlock()
+
+        // Send session.created and conversation.created events to the client
+        sendEvent(c, types.SessionCreatedEvent{
+            ServerEventBase: types.ServerEventBase{
+                EventID: "event_TODO",
+                Type:    types.ServerEventTypeSessionCreated,
+            },
+            Session: session.ToServer(),
+        })
+        sendEvent(c, types.ConversationCreatedEvent{
+            ServerEventBase: types.ServerEventBase{
+                EventID: "event_TODO",
+                Type:    types.ServerEventTypeConversationCreated,
+            },
+            Conversation: conversation.ToServer(),
+        })
+
+        var (
+            // mt int
+            msg  []byte
+            wg   sync.WaitGroup
+            done = make(chan struct{})
+        )
+
+        vadServerStarted := true
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            conversation := session.Conversations[session.DefaultConversationID]
+            handleVAD(cfg, evaluator, session, conversation, c, done)
+        }()
+
+        for {
+            if _, msg, err = c.ReadMessage(); err != nil {
+                log.Error().Msgf("read: %s", err.Error())
+                break
+            }
+
+            // Parse the incoming message
+            var incomingMsg IncomingMessage
+            if err := json.Unmarshal(msg, &incomingMsg); err != nil {
+                log.Error().Msgf("invalid json: %s", err.Error())
+                sendError(c, "invalid_json", "Invalid JSON format", "", "")
+                continue
+            }
+
+            var sessionUpdate types.ClientSession
+            switch incomingMsg.Type {
+            case types.ClientEventTypeTranscriptionSessionUpdate:
+                log.Debug().Msgf("recv: %s", msg)
+
+                if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil {
+                    log.Error().Msgf("failed to unmarshal 'transcription_session.update': %s", err.Error())
+                    sendError(c, "invalid_session_update", "Invalid session update format", "", "")
+                    continue
+                }
+                if err := updateTransSession(
+                    session,
+                    &sessionUpdate,
+                    application.BackendLoader(),
+                    application.ModelLoader(),
+                    application.ApplicationConfig(),
+                ); err != nil {
+                    log.Error().Msgf("failed to update session: %s", err.Error())
+                    sendError(c, "session_update_error", "Failed to update session", "", "")
+                    continue
+                }
+
+                sendEvent(c, types.SessionUpdatedEvent{
+                    ServerEventBase: types.ServerEventBase{
+                        EventID: "event_TODO",
+                        Type:    types.ServerEventTypeTranscriptionSessionUpdated,
+                    },
+                    Session: session.ToServer(),
+                })
+
+            case types.ClientEventTypeSessionUpdate:
+                log.Debug().Msgf("recv: %s", msg)
+
+                // Update session configurations
+                if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil {
+                    log.Error().Msgf("failed to unmarshal 'session.update': %s", err.Error())
+                    sendError(c, "invalid_session_update", "Invalid session update format", "", "")
+                    continue
+                }
+                if err := updateSession(
+                    session,
+                    &sessionUpdate,
+                    application.BackendLoader(),
+                    application.ModelLoader(),
+                    application.ApplicationConfig(),
+                ); err != nil {
+                    log.Error().Msgf("failed to update session: %s", err.Error())
+                    sendError(c, "session_update_error", "Failed to update session", "", "")
+                    continue
+                }
+
+                sendEvent(c, types.SessionUpdatedEvent{
+                    ServerEventBase: types.ServerEventBase{
+                        EventID: "event_TODO",
+                        Type:    types.ServerEventTypeSessionUpdated,
+                    },
+                    Session: session.ToServer(),
+                })
+
+                if session.TurnDetection.Type == types.ServerTurnDetectionTypeServerVad && !vadServerStarted {
+                    log.Debug().Msg("Starting VAD goroutine...")
+                    wg.Add(1)
+                    go func() {
+                        defer wg.Done()
+                        conversation := session.Conversations[session.DefaultConversationID]
+                        handleVAD(cfg, evaluator, session, conversation, c, done)
+                    }()
+                    vadServerStarted = true
+                } else if session.TurnDetection.Type != types.ServerTurnDetectionTypeServerVad && vadServerStarted {
+                    log.Debug().Msg("Stopping VAD goroutine...")
+
+                    // The stopped goroutine's deferred wg.Done() balances the wg.Add(1) above
+                    go func() {
+                        done <- struct{}{}
+                    }()
+                    vadServerStarted = false
+                }
+            case types.ClientEventTypeInputAudioBufferAppend:
+                // Handle 'input_audio_buffer.append'
+                if incomingMsg.Audio == "" {
+                    log.Error().Msg("Audio data is missing in 'input_audio_buffer.append'")
+                    sendError(c, "missing_audio_data", "Audio data is missing", "", "")
+                    continue
+                }
+
+                // Decode base64 audio data
+                decodedAudio, err := base64.StdEncoding.DecodeString(incomingMsg.Audio)
+                if err != nil {
+                    log.Error().Msgf("failed to decode audio data: %s", err.Error())
+                    sendError(c, "invalid_audio_data", "Failed to decode audio data", "", "")
+                    continue
+                }
+
+                // Append to InputAudioBuffer
+                session.AudioBufferLock.Lock()
+                session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...)
+                session.AudioBufferLock.Unlock()
+
+            case types.ClientEventTypeInputAudioBufferCommit:
+                log.Debug().Msgf("recv: %s", msg)
+
+                // TODO: Trigger transcription.
+                // TODO: Ignore this if VAD enabled or interrupt VAD?
+
+                if session.TranscriptionOnly {
+                    continue
+                }
+
+                // Commit the audio buffer to the conversation as a new item
+                item := &types.MessageItem{
+                    ID:     generateItemID(),
+                    Type:   "message",
+                    Status: "completed",
+                    Role:   "user",
+                    Content: []types.MessageContentPart{
+                        {
+                            Type:  "input_audio",
+                            Audio: base64.StdEncoding.EncodeToString(session.InputAudioBuffer),
+                        },
+                    },
+                }
+
+                // Add item to conversation
+                conversation.Lock.Lock()
+                conversation.Items = append(conversation.Items, item)
+                conversation.Lock.Unlock()
+
+                // Reset InputAudioBuffer
+                session.AudioBufferLock.Lock()
+                session.InputAudioBuffer = nil
+                session.AudioBufferLock.Unlock()
+
+                // Send item.created event
+                sendEvent(c, types.ConversationItemCreatedEvent{
+                    ServerEventBase: types.ServerEventBase{
+                        EventID: "event_TODO",
+                        Type:    "conversation.item.created",
+                    },
+                    Item: types.ResponseMessageItem{
+                        Object:      "realtime.item",
+                        MessageItem: *item,
+                    },
+                })
+
+            case types.ClientEventTypeConversationItemCreate:
+                log.Debug().Msgf("recv: %s", msg)
+
+                // Handle creating new conversation items
+                var item types.ConversationItemCreateEvent
+                if err := json.Unmarshal(incomingMsg.Item, &item); err != nil {
+                    log.Error().Msgf("failed to unmarshal 'conversation.item.create': %s", err.Error())
+                    sendError(c, "invalid_item", "Invalid item format", "", "")
+                    continue
+                }
+
+                sendNotImplemented(c, "conversation.item.create")
+
+                // Generate item ID and set status
+                // item.ID = generateItemID()
+                // item.Object = "realtime.item"
+                // item.Status = "completed"
+                //
+                // // Add item to conversation
+                // conversation.Lock.Lock()
+                // conversation.Items = append(conversation.Items, &item)
+                // conversation.Lock.Unlock()
+                //
+                // // Send item.created event
+                // sendEvent(c, OutgoingMessage{
+                //     Type: "conversation.item.created",
+                //     Item: &item,
+                // })
+
+            case types.ClientEventTypeConversationItemDelete:
+                sendError(c, "not_implemented", "Deleting items not implemented", "", "event_TODO")
+
+            case types.ClientEventTypeResponseCreate:
+                // Handle generating a response
+                var responseCreate types.ResponseCreateEvent
+                if len(incomingMsg.Response) > 0 {
+                    if err := json.Unmarshal(incomingMsg.Response, &responseCreate); err != nil {
+                        log.Error().Msgf("failed to unmarshal 'response.create' response object: %s", err.Error())
+                        sendError(c, "invalid_response_create", "Invalid response create format", "", "")
+                        continue
+                    }
+                }
+
+                // Update session functions if provided
+                if len(responseCreate.Response.Tools) > 0 {
+                    // TODO: Tools -> Functions
+                }
+
+                sendNotImplemented(c, "response.create")
+
+                // TODO: Generate a response based on the conversation history
+                // wg.Add(1)
+                // go func() {
+                //     defer wg.Done()
+                //     generateResponse(cfg, evaluator, session, conversation, responseCreate, c, mt)
+                // }()
+
+            case types.ClientEventTypeResponseCancel:
+                log.Printf("recv: %s", msg)
+
+                // Handle cancellation of ongoing responses
+                // Implement cancellation logic as needed
+                sendNotImplemented(c, "response.cancel")
+
+            default:
+                log.Error().Msgf("unknown message type: %s", incomingMsg.Type)
+                sendError(c, "unknown_message_type", fmt.Sprintf("Unknown message type: %s", incomingMsg.Type), "", "")
+            }
+        }
+
+        // Close the done channel to signal goroutines to exit
+        close(done)
+        wg.Wait()
+
+        // Remove the session from the sessions map
+        sessionLock.Lock()
+        delete(sessions, sessionID)
+        sessionLock.Unlock()
+    }
+}
+
+// Helper function to send events to the client
+func sendEvent(c *websocket.Conn, event types.ServerEvent) {
+    eventBytes, err := json.Marshal(event)
+    if err != nil {
+        log.Error().Msgf("failed to marshal event: %s", err.Error())
+        return
+    }
+    if err = c.WriteMessage(websocket.TextMessage, eventBytes); err != nil {
+        log.Error().Msgf("write: %s", err.Error())
+    }
+}
+
+// Helper function to send errors to the client
+func sendError(c *websocket.Conn, code, message, param, eventID string) {
+    errorEvent := types.ErrorEvent{
+        ServerEventBase: types.ServerEventBase{
+            Type:    types.ServerEventTypeError,
+            EventID: eventID,
+        },
+        Error: types.Error{
+            Type:    "invalid_request_error",
+            Code:    code,
+            Message: message,
+            EventID: eventID,
+        },
+    }
+
+    sendEvent(c, errorEvent)
+}
+
+func sendNotImplemented(c *websocket.Conn, message string) {
+    sendError(c, "not_implemented", message, "", "event_TODO")
+}
+
+func updateTransSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
+    sessionLock.Lock()
+    defer sessionLock.Unlock()
+
+    trUpd := update.InputAudioTranscription
+    trCur := session.InputAudioTranscription
+
+    if trUpd != nil && trUpd.Model != "" && trUpd.Model != trCur.Model {
+        pipeline := config.Pipeline{
+            VAD:           "silero-vad",
+            Transcription: trUpd.Model,
+        }
+
+        m, _, err := newTranscriptionOnlyModel(&pipeline, cl, ml, appConfig)
+        if err != nil {
+            return err
+        }
+
+        session.ModelInterface = m
+    }
+
+    if update.TurnDetection != nil && update.TurnDetection.Type != "" {
+        session.TurnDetection.Type = types.ServerTurnDetectionType(update.TurnDetection.Type)
+        session.TurnDetection.TurnDetectionParams = update.TurnDetection.TurnDetectionParams
+    }
+
+    return nil
+}
+
+// Function to update session configurations
+func updateSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
+    sessionLock.Lock()
+    defer sessionLock.Unlock()
+
+    if update.Model != "" {
+        pipeline := config.Pipeline{
+            LLM: update.Model,
+            // TODO: Setup pipeline by configuring STT and TTS models
+        }
+        m, err := newModel(&pipeline, cl, ml, appConfig)
+        if err != nil {
+            return err
+        }
+        session.ModelInterface = m
+        session.Model = update.Model
+    }
+
+    if update.Voice != "" {
+        session.Voice = update.Voice
+    }
+    if update.TurnDetection != nil && update.TurnDetection.Type != "" {
+        session.TurnDetection.Type = types.ServerTurnDetectionType(update.TurnDetection.Type)
+        session.TurnDetection.TurnDetectionParams = update.TurnDetection.TurnDetectionParams
+    }
+    // TODO: We should actually check if the field was present in the JSON; empty string means clear the settings
+    if update.Instructions != "" {
+        session.Instructions = update.Instructions
+    }
+    if update.Tools != nil {
+        return fmt.Errorf("tools are not implemented yet")
+    }
+
+    session.InputAudioTranscription = update.InputAudioTranscription
+
+    return nil
+}
+
+// handleVAD is a goroutine that listens for audio data from the client,
+// runs VAD on the audio data, and commits utterances to the conversation
+func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
+    vadContext, cancel := context.WithCancel(context.Background())
+    go func() {
+        <-done
+        cancel()
+    }()
+
+    silenceThreshold := float64(session.TurnDetection.SilenceDurationMs) / 1000
+
+    ticker := time.NewTicker(300 * time.Millisecond)
+    defer ticker.Stop()
+
+    for {
+        select {
+        case <-done:
+            return
+        case <-ticker.C:
+            session.AudioBufferLock.Lock()
+            allAudio := make([]byte, len(session.InputAudioBuffer))
+            copy(allAudio, session.InputAudioBuffer)
+            session.AudioBufferLock.Unlock()
+
+            aints := sound.BytesToInt16sLE(allAudio)
+            if len(aints) == 0 || len(aints) < int(silenceThreshold)*remoteSampleRate {
+                continue
+            }
+
+            // Resample from 24kHz to 16kHz
+            aints = sound.ResampleInt16(aints, remoteSampleRate, localSampleRate)
+
+            segments, err := runVAD(vadContext, session, aints)
+            if err != nil {
+                if err.Error() == "unexpected speech end" {
+                    log.Debug().Msg("VAD cancelled")
+                    continue
+                }
+                log.Error().Msgf("failed to process audio: %s", err.Error())
+                sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+                continue
+            }
+
+            audioLength := float64(len(aints)) / localSampleRate
+
+            // TODO: When resetting the buffer we should retain a small postfix
+            // TODO: The OpenAI documentation seems to suggest that only the client decides when to clear the buffer
+            if len(segments) == 0 && audioLength > silenceThreshold {
+                session.AudioBufferLock.Lock()
+                session.InputAudioBuffer = nil
+                session.AudioBufferLock.Unlock()
+                log.Debug().Msgf("Detected silence for a while, clearing audio buffer")
+
+                sendEvent(c, types.InputAudioBufferClearedEvent{
+                    ServerEventBase: types.ServerEventBase{
+                        EventID: "event_TODO",
+                        Type:    types.ServerEventTypeInputAudioBufferCleared,
+                    },
+                })
+
+                continue
+            } else if len(segments) == 0 {
+                continue
+            }
+
+            // TODO: Send input_audio_buffer.speech_started and input_audio_buffer.speech_stopped
+
+            // Segment still in progress when audio ended
+            segEndTime := segments[len(segments)-1].GetEnd()
+            if segEndTime == 0 {
+                continue
+            }
+
+            if float32(audioLength)-segEndTime > float32(silenceThreshold) {
+                log.Debug().Msgf("Detected end of speech segment")
+                session.AudioBufferLock.Lock()
+                session.InputAudioBuffer = nil
+                session.AudioBufferLock.Unlock()
+
+                sendEvent(c, types.InputAudioBufferCommittedEvent{
+                    ServerEventBase: types.ServerEventBase{
+                        EventID: "event_TODO",
+                        Type:    types.ServerEventTypeInputAudioBufferCommitted,
+                    },
+                    ItemID:         generateItemID(),
+                    PreviousItemID: "TODO",
+                })
+
+                abytes := sound.Int16toBytesLE(aints)
+                // TODO: Remove prefix silence that is over TurnDetectionParams.PrefixPaddingMs
+                go commitUtterance(vadContext, abytes, cfg, evaluator, session, conv, c)
+            }
+        }
+    }
+}
+
+func commitUtterance(ctx context.Context, utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
+    if len(utt) == 0 {
+        return
+    }
+
+    // TODO: If we have a real any-to-any model then transcription is optional
+
+    f, err := os.CreateTemp("", "realtime-audio-chunk-*.wav")
+    if err != nil {
+        log.Error().Msgf("failed to create temp file: %s", err.Error())
+        return
+    }
+    defer f.Close()
+    defer os.Remove(f.Name())
+    log.Debug().Msgf("Writing to %s\n", f.Name())
+
+    hdr := laudio.NewWAVHeader(uint32(len(utt)))
+    if err := hdr.Write(f); err != nil {
+        log.Error().Msgf("Failed to write WAV header: %s", err.Error())
+        return
+    }
+
+    if _, err := f.Write(utt); err != nil {
+        log.Error().Msgf("Failed to write audio data: %s", err.Error())
+        return
+    }
+
+    f.Sync()
+
+    if session.InputAudioTranscription != nil {
+        tr, err := session.ModelInterface.Transcribe(ctx, &proto.TranscriptRequest{
+            Dst:       f.Name(),
+            Language:  session.InputAudioTranscription.Language,
+            Translate: false,
+            Threads:   uint32(*cfg.Threads),
+        })
+        if err != nil {
+            sendError(c, "transcription_failed", err.Error(), "", "event_TODO")
+            return
+        }
+
+        sendEvent(c, types.ResponseAudioTranscriptDoneEvent{
+            ServerEventBase: types.ServerEventBase{
+                Type:    types.ServerEventTypeResponseAudioTranscriptDone,
+                EventID: "event_TODO",
+            },
+            ItemID:       generateItemID(),
+            ResponseID:   "resp_TODO",
+            OutputIndex:  0,
+            ContentIndex: 0,
+            Transcript:   tr.GetText(),
+        })
+        // TODO: Update the prompt with transcription result?
+    }
+
+    if !session.TranscriptionOnly {
+        sendNotImplemented(c, "Committing items to the conversation not implemented")
+    }
+
+    // TODO: Commit the audio and/or transcribed text to the conversation
+    // Commit logic: create item, broadcast item.created, etc.
+    // item := &Item{
+    //     ID:     generateItemID(),
+    //     Object: "realtime.item",
+    //     Type:   "message",
+    //     Status: "completed",
+    //     Role:   "user",
+    //     Content: []ConversationContent{
+    //         {
+    //             Type:  "input_audio",
+    //             Audio: base64.StdEncoding.EncodeToString(utt),
+    //         },
+    //     },
+    // }
+    // conv.Lock.Lock()
+    // conv.Items = append(conv.Items, item)
+    // conv.Lock.Unlock()
+    //
+    // sendEvent(c, OutgoingMessage{
+    //     Type: "conversation.item.created",
+    //     Item: item,
+    // })
+    //
+    // // trigger the response generation
+    // generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage)
+}
+
+func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADSegment, error) {
+    soundIntBuffer := &audio.IntBuffer{
+        Format:         &audio.Format{SampleRate: localSampleRate, NumChannels: 1},
+        SourceBitDepth: 16,
+        Data:           sound.ConvertInt16ToInt(adata),
+    }
+
+    float32Data := soundIntBuffer.AsFloat32Buffer().Data
+
+    resp, err := session.ModelInterface.VAD(ctx, &proto.VADRequest{
+        Audio: float32Data,
+    })
+    if err != nil {
+        return nil, err
+    }
+
+    // If resp.Segments is empty => no speech
+    return resp.Segments, nil
+}
+
+// TODO: Below needed for normal mode instead of transcription only
+// Function to generate a response based on the conversation
+// func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
+//
+//     log.Debug().Msg("Generating realtime response...")
+//
+//     // Compile the conversation history
+//     conversation.Lock.Lock()
+//     var conversationHistory []schema.Message
+//     var latestUserAudio string
+//     for _, item := range conversation.Items {
+//         for _, content := range item.Content {
+//             switch content.Type {
+//             case "input_text", "text":
+//                 conversationHistory = append(conversationHistory, schema.Message{
+//                     Role:          string(item.Role),
+//                     StringContent: content.Text,
+//                     Content:       content.Text,
+//                 })
+//             case "input_audio":
+//                 // We do not transcribe the audio result to text here.
+//                 // When generating the reply later on from the LLM,
+//                 // we will also generate text, return it, and store it in the conversation.
+//                 // Here we just want to pick up the latest user audio, if any, as a new input for the conversation.
+//                 if item.Role == "user" {
+//                     latestUserAudio = content.Audio
+//                 }
+//             }
+//         }
+//     }
+//
+//     conversation.Lock.Unlock()
+//
+//     var generatedText string
+//     var generatedAudio []byte
+//     var functionCall *FunctionCall
+//     var err error
+//
+//     if latestUserAudio != "" {
+//         // Process the latest user audio input
+//         decodedAudio, err := base64.StdEncoding.DecodeString(latestUserAudio)
+//         if err != nil {
+//             log.Error().Msgf("failed to decode latest user audio: %s", err.Error())
+//             sendError(c, "invalid_audio_data", "Failed to decode audio data", "", "")
+//             return
+//         }
+//
+//         // Process the audio input and generate a response
+//         generatedText, generatedAudio, functionCall, err = processAudioResponse(session, decodedAudio)
+//         if err != nil {
+//             log.Error().Msgf("failed to process audio response: %s", err.Error())
+//             sendError(c, "processing_error", "Failed to generate audio response", "", "")
+//             return
+//         }
+//     } else {
+//
+//         if session.Instructions != "" {
+//             conversationHistory = append([]schema.Message{{
+//                 Role:          "system",
+//                 StringContent: session.Instructions,
+//                 Content:       session.Instructions,
+//             }}, conversationHistory...)
+//         }
+//
+//         funcs := session.Functions
+//         shouldUseFn := len(funcs) > 0 && config.ShouldUseFunctions()
+//
+//         // Allow the user to set custom actions via config file
+//         // to be "embedded" in each model
+//         noActionName := "answer"
+//         noActionDescription := "use this action to answer without performing any action"
+//
+//         if config.FunctionsConfig.NoActionFunctionName != "" {
+//             noActionName = config.FunctionsConfig.NoActionFunctionName
+//         }
+//         if config.FunctionsConfig.NoActionDescriptionName != "" {
+//             noActionDescription = config.FunctionsConfig.NoActionDescriptionName
+//         }
+//
+//         if (!config.FunctionsConfig.GrammarConfig.NoGrammar) && shouldUseFn {
+//             noActionGrammar := functions.Function{
+//                 Name:        noActionName,
+//                 Description: noActionDescription,
+//                 Parameters: map[string]interface{}{
+//                     "properties": map[string]interface{}{
+//                         "message": map[string]interface{}{
+//                             "type":        "string",
+//                             "description": "The message to reply the user with",
+//                         }},
+//                 },
+//             }
+//
+//             // Append the no action function
+//             if !config.FunctionsConfig.DisableNoAction {
+//                 funcs = append(funcs, noActionGrammar)
+//             }
+//
+//             // Update input grammar
+//             jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
+//             g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
+//             if err == nil {
+//                 config.Grammar = g
+//             }
+//         }
+//
+//         // Generate a response based on text conversation history
+//         prompt := evaluator.TemplateMessages(conversationHistory, config, funcs, shouldUseFn)
+//
+//         generatedText, functionCall, err = processTextResponse(config, session, prompt)
+//         if err != nil {
+//             log.Error().Msgf("failed to process text response: %s", err.Error())
+//             sendError(c, "processing_error", "Failed to generate text response", "", "")
+//             return
+//         }
+//         log.Debug().Any("text", generatedText).Msg("Generated text response")
+//     }
+//
+//     if functionCall != nil {
+//         // The model wants to call a function
+//         // Create a function_call item and send it to the client
+//         item := &Item{
+//             ID:           generateItemID(),
+//             Object:       "realtime.item",
+//             Type:         "function_call",
+//             Status:       "completed",
+//             Role:         "assistant",
+//             FunctionCall: functionCall,
+//         }
+//
+//         // Add item to conversation
+//         conversation.Lock.Lock()
+//         conversation.Items = append(conversation.Items, item)
+//         conversation.Lock.Unlock()
+//
+//         // Send item.created event
+//         sendEvent(c, OutgoingMessage{
+//             Type: "conversation.item.created",
+//             Item: item,
+//         })
+//
+//         // Optionally, you can generate a message to the user indicating the function call
+//         // For now, we'll assume the client handles the function call and may trigger another response
+//
+//     } else {
+//         // Send response.stream messages
+//         if generatedAudio != nil {
+//             // If generatedAudio is available, send it as audio
+//             encodedAudio := base64.StdEncoding.EncodeToString(generatedAudio)
+//             outgoingMsg := OutgoingMessage{
+//                 Type:  "response.stream",
+//                 Audio: encodedAudio,
+//             }
+//             sendEvent(c, outgoingMsg)
+//         } else {
+//             // Send text response (could be streamed in chunks)
+//             chunks := splitResponseIntoChunks(generatedText)
+//             for _, chunk := range chunks {
+//                 outgoingMsg := OutgoingMessage{
+//                     Type:    "response.stream",
+//                     Content: chunk,
+//                 }
+//                 sendEvent(c, outgoingMsg)
+//             }
+//         }
+//
+//         // Send response.done message
+//         sendEvent(c, OutgoingMessage{
+//             Type: "response.done",
+//         })
+//
+//         // Add the assistant's response to the conversation
+//         content := []ConversationContent{}
+//         if generatedAudio != nil {
+//             content = append(content, ConversationContent{
+//                 Type:  "audio",
+//                 Audio: base64.StdEncoding.EncodeToString(generatedAudio),
+//             })
+//             // Optionally include a text transcript
+//             if generatedText != "" {
+//                 content = append(content, ConversationContent{
+//                     Type: "text",
+//                     Text: generatedText,
+//                 })
+//             }
+//         } else {
+//             content = append(content, ConversationContent{
+//                 Type: "text",
+//                 Text: generatedText,
+//             })
+//         }
+//
+//         item := &Item{
+//             ID:      generateItemID(),
+//             Object:  "realtime.item",
+//             Type:    "message",
+//             Status:  "completed",
+//             Role:    "assistant",
+//             Content: content,
+//         }
+//
+//         // Add item to conversation
+//         conversation.Lock.Lock()
+//         conversation.Items = append(conversation.Items, item)
+//         conversation.Lock.Unlock()
+//
+//         // Send item.created event
+//         sendEvent(c, OutgoingMessage{
+//             Type: "conversation.item.created",
+//             Item: item,
+//         })
+//
+//         log.Debug().Any("item", item).Msg("Realtime response sent")
+//     }
+// }
+
+// Function to process text response and detect function calls
+func processTextResponse(config *config.BackendConfig, session *Session, prompt string) (string, *FunctionCall, error) {
+
+    // Placeholder implementation
+    // Replace this with actual model inference logic using session.Model and prompt
+    // For example, the model might return a special token or JSON indicating a function call
+
+    /*
+        predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
+
+        result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
+            if !shouldUseFn {
+                // no function is called, just reply and use stop as finish reason
+                *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
+                return
+            }
+
+            textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
+            s = functions.CleanupLLMResult(s, config.FunctionsConfig)
+            results := functions.ParseFunctionCall(s, config.FunctionsConfig)
+            log.Debug().Msgf("Text content to return: %s", textContentToReturn)
+            noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
+
+            switch {
+            case noActionsToRun:
+                result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
+                if err != nil {
+                    log.Error().Err(err).Msg("error handling question")
+                    return
+                }
+                *c = append(*c, schema.Choice{
+                    Message: &schema.Message{Role: "assistant", Content: &result}})
+            default:
+                toolChoice := schema.Choice{
+                    Message: &schema.Message{
+                        Role: "assistant",
+                    },
+                }
+
+                if len(input.Tools) > 0 {
+                    toolChoice.FinishReason = "tool_calls"
+                }
+
+                for _, ss := range results {
+                    name, args := ss.Name, ss.Arguments
+                    if len(input.Tools) > 0 {
+                        // If we are using tools, we condense the function calls into
+                        // a single response choice with all the tools
+                        toolChoice.Message.Content = textContentToReturn
+                        toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
+                            schema.ToolCall{
+                                ID:   id,
+                                Type: "function",
+                                FunctionCall: schema.FunctionCall{
+                                    Name:      name,
+                                    Arguments: args,
+                                },
+                            },
+                        )
+                    } else {
+                        // otherwise we return more choices directly
+                        *c = append(*c, schema.Choice{
+                            FinishReason: "function_call",
+                            Message: &schema.Message{
+                                Role:    "assistant",
+                                Content: &textContentToReturn,
+                                FunctionCall: map[string]interface{}{
+                                    "name":      name,
+                                    "arguments": args,
+                                },
+                            },
+                        })
+                    }
+                }
+
+                if len(input.Tools) > 0 {
+                    // we need to append our result if we are using tools
+                    *c = append(*c, toolChoice)
+                }
+            }
+
+        }, nil)
+        if err != nil {
+            return err
+        }
+
+        resp := &schema.OpenAIResponse{
+            ID:      id,
+            Created: created,
+            Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+            Choices: result,
+            Object:  "chat.completion",
+            Usage: schema.OpenAIUsage{
+                PromptTokens:     tokenUsage.Prompt,
+                CompletionTokens: tokenUsage.Completion,
+                TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
+            },
+        }
+        respData, _ := json.Marshal(resp)
+        log.Debug().Msgf("Response: %s", respData)
+
+        // Return the prediction in the response body
+        return c.JSON(resp)
+
+    */
+
+    // TODO: use session.ModelInterface...
+    // Simulate a function call
+    if strings.Contains(prompt, "weather") {
+        functionCall := &FunctionCall{
+            Name: "get_weather",
+            Arguments: map[string]interface{}{
+                "location": "New York",
+                "scale":    "celsius",
+            },
+        }
+        return "", functionCall, nil
+    }
+
+    // Otherwise, return a normal text response
+    return "This is a generated response based on the conversation.", nil, nil
+}
+
+// Function to process audio response and detect function calls
+func processAudioResponse(session *Session, audioData []byte) (string, []byte, *FunctionCall, error) {
+    // TODO: Do the below or use an any-to-any model like Qwen Omni
+    // Implement the actual model inference logic using session.Model and audioData
+    // For example:
+    // 1. Transcribe the audio to text
+    // 2. Generate a response based on the transcribed text
+    // 3. Check if the model wants to call a function
+    // 4. Convert the response text to speech (audio)
+    //
+    // Placeholder implementation:
+
+    // TODO: template eventual messages, like chat.go
+    reply, err := session.ModelInterface.Predict(context.Background(), &proto.PredictOptions{
+        Prompt: "What's the weather in New York?",
+    })
+    if err != nil {
+        return "", nil, nil, err
+    }
+
+    generatedAudio := reply.Audio
+
+    transcribedText := "What's the weather in New York?"
+    var functionCall *FunctionCall
+
+    // Simulate a function call
+    if strings.Contains(transcribedText, "weather") {
+        functionCall = &FunctionCall{
+            Name: "get_weather",
+            Arguments: map[string]interface{}{
+                "location": "New York",
+                "scale":    "celsius",
+            },
+        }
+        return "", nil, functionCall, nil
+    }
+
+    // Generate a response
+    generatedText := "This is a response to your speech input."
+
+    return generatedText, generatedAudio, nil, nil
+}
+
+// Function to split the response into chunks (for streaming)
+func splitResponseIntoChunks(response string) []string {
+    // Split the response into chunks of fixed size
+    chunkSize := 50 // characters per chunk
+    var chunks []string
+    for len(response) > 0 {
+        if len(response) > chunkSize {
+            chunks = append(chunks, response[:chunkSize])
+            response = response[chunkSize:]
+        } else {
+            chunks = append(chunks, response)
+            break
+        }
+    }
+    return chunks
+}
+
+// Helper functions to generate unique IDs
+func generateSessionID() string {
+    // Generate a unique session ID
+    return "sess_" + generateUniqueID()
+}
+
+func generateConversationID() string {
+    // Generate a unique conversation ID
+    return "conv_" + generateUniqueID()
+}
+
+func generateItemID() string {
+    // Generate a unique item ID
+    return "item_" + generateUniqueID()
+}
+
+func generateUniqueID() string {
+    // Generate a unique ID string; a nanosecond timestamp is enough to
+    // disambiguate the sessions, conversations and items of a single process
+    return fmt.Sprintf("%d", time.Now().UnixNano())
+}
+
+// Structures for 'response.create' messages
+type ResponseCreate struct {
+    Modalities   []string            `json:"modalities,omitempty"`
+    Instructions string              `json:"instructions,omitempty"`
+    Functions    functions.Functions `json:"functions,omitempty"`
+    // Other fields as needed
+}
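The Reply message now carries raw audio bytes (the backend.proto hunk above), which is what processAudioResponse reads from reply.Audio. When this is eventually streamed to the client, the bytes would have to be base64-encoded on the wire. A sketch of such an encoder; the "response.audio.delta" event shape follows the OpenAI realtime API and is an assumption here, not something this patch emits yet:

    // audioDeltaPayload base64-encodes backend PCM16 audio for a hypothetical
    // response.audio.delta server event.
    func audioDeltaPayload(reply *proto.Reply) map[string]string {
        return map[string]string{
            "type":  "response.audio.delta",
            "delta": base64.StdEncoding.EncodeToString(reply.Audio),
        }
    }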
+ if err != nil { + return nil, fmt.Errorf("failed to load SST model: %w", err) + } + + // TODO: Decide when we have a real any-to-any model + if false { + + cfgAnyToAny, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath) + if err != nil { + + return nil, fmt.Errorf("failed to load backend config: %w", err) + } + + if !cfgAnyToAny.Validate() { + return nil, fmt.Errorf("failed to validate config: %w", err) + } + + opts := backend.ModelOptions(*cfgAnyToAny, appConfig) + anyToAnyClient, err := ml.Load(opts...) + if err != nil { + return nil, fmt.Errorf("failed to load tts model: %w", err) + } + + return &anyToAnyModel{ + LLMConfig: cfgAnyToAny, + LLMClient: anyToAnyClient, + VADConfig: cfgVAD, + VADClient: VADClient, + }, nil + } + + log.Debug().Msg("Loading a wrapped model") + + // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations + cfgLLM, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath) + if err != nil { + + return nil, fmt.Errorf("failed to load backend config: %w", err) + } + + if !cfgLLM.Validate() { + return nil, fmt.Errorf("failed to validate config: %w", err) + } + + cfgTTS, err := cl.LoadBackendConfigFileByName(pipeline.TTS, ml.ModelPath) + if err != nil { + + return nil, fmt.Errorf("failed to load backend config: %w", err) + } + + if !cfgTTS.Validate() { + return nil, fmt.Errorf("failed to validate config: %w", err) + } + + + opts = backend.ModelOptions(*cfgTTS, appConfig) + ttsClient, err := ml.Load(opts...) + if err != nil { + return nil, fmt.Errorf("failed to load tts model: %w", err) + } + + opts = backend.ModelOptions(*cfgLLM, appConfig) + llmClient, err := ml.Load(opts...) + if err != nil { + return nil, fmt.Errorf("failed to load LLM model: %w", err) + } + + return &wrappedModel{ + TTSConfig: cfgTTS, + TranscriptionConfig: cfgSST, + LLMConfig: cfgLLM, + TTSClient: ttsClient, + TranscriptionClient: transcriptionClient, + LLMClient: llmClient, + + VADConfig: cfgVAD, + VADClient: VADClient, + }, nil +} diff --git a/core/http/endpoints/openai/types/realtime.go b/core/http/endpoints/openai/types/realtime.go new file mode 100644 index 00000000..2da0600b --- /dev/null +++ b/core/http/endpoints/openai/types/realtime.go @@ -0,0 +1,1178 @@ +package types + +// Most of this file was coppied from https://github.com/WqyJh/go-openai-realtime +// Copyright (c) 2024 Qiying Wang MIT License + +import ( + "encoding/json" + "fmt" + "math" +) + +const ( + // Inf is the maximum value for an IntOrInf. + Inf IntOrInf = math.MaxInt +) + +// IntOrInf is a type that can be either an int or "inf". +type IntOrInf int + +// IsInf returns true if the value is "inf". +func (m IntOrInf) IsInf() bool { + return m == Inf +} + +// MarshalJSON marshals the IntOrInf to JSON. +func (m IntOrInf) MarshalJSON() ([]byte, error) { + if m == Inf { + return []byte("\"inf\""), nil + } + return json.Marshal(int(m)) +} + +// UnmarshalJSON unmarshals the IntOrInf from JSON. 
+func (m *IntOrInf) UnmarshalJSON(data []byte) error { + if string(data) == "\"inf\"" { + *m = Inf + return nil + } + if len(data) == 0 { + return nil + } + return json.Unmarshal(data, (*int)(m)) +} + +type AudioFormat string + +const ( + AudioFormatPcm16 AudioFormat = "pcm16" + AudioFormatG711Ulaw AudioFormat = "g711_ulaw" + AudioFormatG711Alaw AudioFormat = "g711_alaw" +) + +type Modality string + +const ( + ModalityText Modality = "text" + ModalityAudio Modality = "audio" +) + +type ClientTurnDetectionType string + +const ( + ClientTurnDetectionTypeServerVad ClientTurnDetectionType = "server_vad" +) + +type ServerTurnDetectionType string + +const ( + ServerTurnDetectionTypeNone ServerTurnDetectionType = "none" + ServerTurnDetectionTypeServerVad ServerTurnDetectionType = "server_vad" +) + +type TurnDetectionType string + +const ( + // TurnDetectionTypeNone means turn detection is disabled. + // This can only be used in ServerSession, not in ClientSession. + // If you want to disable turn detection, you should send SessionUpdateEvent with TurnDetection set to nil. + TurnDetectionTypeNone TurnDetectionType = "none" + // TurnDetectionTypeServerVad use server-side VAD to detect turn. + // This is default value for newly created session. + TurnDetectionTypeServerVad TurnDetectionType = "server_vad" +) + +type TurnDetectionParams struct { + // Activation threshold for VAD. + Threshold float64 `json:"threshold,omitempty"` + // Audio included before speech starts (in milliseconds). + PrefixPaddingMs int `json:"prefix_padding_ms,omitempty"` + // Duration of silence to detect speech stop (in milliseconds). + SilenceDurationMs int `json:"silence_duration_ms,omitempty"` + // Whether or not to automatically generate a response when VAD is enabled. true by default. + CreateResponse *bool `json:"create_response,omitempty"` +} + +type ClientTurnDetection struct { + // Type of turn detection, only "server_vad" is currently supported. + Type ClientTurnDetectionType `json:"type"` + + TurnDetectionParams +} + +type ServerTurnDetection struct { + // The type of turn detection ("server_vad" or "none"). + Type ServerTurnDetectionType `json:"type"` + + TurnDetectionParams +} + +type ToolType string + +const ( + ToolTypeFunction ToolType = "function" +) + +type ToolChoiceInterface interface { + ToolChoice() +} + +type ToolChoiceString string + +func (ToolChoiceString) ToolChoice() {} + +const ( + ToolChoiceAuto ToolChoiceString = "auto" + ToolChoiceNone ToolChoiceString = "none" + ToolChoiceRequired ToolChoiceString = "required" +) + +type ToolChoice struct { + Type ToolType `json:"type"` + Function ToolFunction `json:"function,omitempty"` +} + +func (t ToolChoice) ToolChoice() {} + +type ToolFunction struct { + Name string `json:"name"` +} + +type MessageRole string + +const ( + MessageRoleSystem MessageRole = "system" + MessageRoleAssistant MessageRole = "assistant" + MessageRoleUser MessageRole = "user" +) + +type InputAudioTranscription struct { + // The model used for transcription. 
+	Model    string `json:"model"`
+	Language string `json:"language,omitempty"`
+	Prompt   string `json:"prompt,omitempty"`
+}
+
+type Tool struct {
+	Type        ToolType `json:"type"`
+	Name        string   `json:"name"`
+	Description string   `json:"description"`
+	Parameters  any      `json:"parameters"`
+}
+
+type MessageItemType string
+
+const (
+	MessageItemTypeMessage            MessageItemType = "message"
+	MessageItemTypeFunctionCall       MessageItemType = "function_call"
+	MessageItemTypeFunctionCallOutput MessageItemType = "function_call_output"
+)
+
+type MessageContentType string
+
+const (
+	MessageContentTypeText       MessageContentType = "text"
+	MessageContentTypeAudio      MessageContentType = "audio"
+	MessageContentTypeTranscript MessageContentType = "transcript"
+	MessageContentTypeInputText  MessageContentType = "input_text"
+	MessageContentTypeInputAudio MessageContentType = "input_audio"
+)
+
+type MessageContentPart struct {
+	// The content type.
+	Type MessageContentType `json:"type"`
+	// The text content. Present when the type is "text".
+	Text string `json:"text,omitempty"`
+	// Base64-encoded audio data. Present when the type is "audio".
+	Audio string `json:"audio,omitempty"`
+	// The transcript of the audio. Present when the type is "transcript".
+	Transcript string `json:"transcript,omitempty"`
+}
+
+type MessageItem struct {
+	// The unique ID of the item.
+	ID string `json:"id,omitempty"`
+	// The type of the item ("message", "function_call", "function_call_output").
+	Type MessageItemType `json:"type"`
+	// The final status of the item.
+	Status ItemStatus `json:"status,omitempty"`
+	// The role associated with the item.
+	Role MessageRole `json:"role,omitempty"`
+	// The content of the item.
+	Content []MessageContentPart `json:"content,omitempty"`
+	// The ID of the function call, if the item is a function call.
+	CallID string `json:"call_id,omitempty"`
+	// The name of the function, if the item is a function call.
+	Name string `json:"name,omitempty"`
+	// The arguments of the function, if the item is a function call.
+	Arguments string `json:"arguments,omitempty"`
+	// The output of the function, if the item is a function call output.
+	Output string `json:"output,omitempty"`
+}
+
+type ResponseMessageItem struct {
+	MessageItem
+	// The object type, must be "realtime.item".
+	Object string `json:"object,omitempty"`
+}
+
+type Error struct {
+	// A human-readable error message.
+	Message string `json:"message,omitempty"`
+	// The type of error (e.g., "invalid_request_error", "server_error").
+	Type string `json:"type,omitempty"`
+	// Error code, if any.
+	Code string `json:"code,omitempty"`
+	// Parameter related to the error, if any.
+	Param string `json:"param,omitempty"`
+	// The event_id of the client event that caused the error, if applicable.
+	EventID string `json:"event_id,omitempty"`
+}
+
+// ServerToolChoice holds the tool choice sent by the server, which may be
+// either a plain string (e.g. "auto") or a function object.
+type ServerToolChoice struct {
+	String   ToolChoiceString
+	Function ToolChoice
+}
+
+// UnmarshalJSON is a custom unmarshaler for ServerToolChoice.
+func (m *ServerToolChoice) UnmarshalJSON(data []byte) error {
+	err := json.Unmarshal(data, &m.Function)
+	if err != nil {
+		// Not a function object - treat the payload as a bare string such as "auto".
+		if data[0] == '"' {
+			data = data[1:]
+		}
+		if data[len(data)-1] == '"' {
+			data = data[:len(data)-1]
+		}
+		m.String = ToolChoiceString(data)
+		m.Function = ToolChoice{}
+	}
+	return nil
+}
+
+// IsFunction returns true if the tool choice is a function call.
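+// For example (illustrative), both accepted JSON shapes round-trip through
+// UnmarshalJSON above:
+//
+//	var tc ServerToolChoice
+//	_ = tc.UnmarshalJSON([]byte(`"auto"`))                                      // tc.Get() == ToolChoiceAuto
+//	_ = tc.UnmarshalJSON([]byte(`{"type":"function","function":{"name":"f"}}`)) // tc.IsFunction() == true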
+func (m *ServerToolChoice) IsFunction() bool { + return m.Function.Type == ToolTypeFunction +} + +// Get returns the ToolChoiceInterface based on the type of tool choice. +func (m ServerToolChoice) Get() ToolChoiceInterface { + if m.IsFunction() { + return m.Function + } + return m.String +} + +type ServerSession struct { + // The unique ID of the session. + ID string `json:"id"` + // The object type, must be "realtime.session". + Object string `json:"object"` + // The default model used for this session. + Model string `json:"model"` + // The set of modalities the model can respond with. + Modalities []Modality `json:"modalities,omitempty"` + // The default system instructions. + Instructions string `json:"instructions,omitempty"` + // The voice the model uses to respond - one of alloy, echo, or shimmer. + Voice string `json:"voice,omitempty"` + // The format of input audio. + InputAudioFormat AudioFormat `json:"input_audio_format,omitempty"` + // The format of output audio. + OutputAudioFormat AudioFormat `json:"output_audio_format,omitempty"` + // Configuration for input audio transcription. + InputAudioTranscription *InputAudioTranscription `json:"input_audio_transcription,omitempty"` + // Configuration for turn detection. + TurnDetection *ServerTurnDetection `json:"turn_detection,omitempty"` + // Tools (functions) available to the model. + Tools []Tool `json:"tools,omitempty"` + // How the model chooses tools. + ToolChoice ServerToolChoice `json:"tool_choice,omitempty"` + // Sampling temperature. + Temperature *float32 `json:"temperature,omitempty"` + // Maximum number of output tokens. + MaxOutputTokens IntOrInf `json:"max_response_output_tokens,omitempty"` +} + +type ItemStatus string + +const ( + ItemStatusInProgress ItemStatus = "in_progress" + ItemStatusCompleted ItemStatus = "completed" + ItemStatusIncomplete ItemStatus = "incomplete" +) + +type Conversation struct { + // The unique ID of the conversation. + ID string `json:"id"` + // The object type, must be "realtime.conversation". + Object string `json:"object"` +} + +type ResponseStatus string + +const ( + ResponseStatusInProgress ResponseStatus = "in_progress" + ResponseStatusCompleted ResponseStatus = "completed" + ResponseStatusCancelled ResponseStatus = "cancelled" + ResponseStatusIncomplete ResponseStatus = "incomplete" + ResponseStatusFailed ResponseStatus = "failed" +) + +type CachedTokensDetails struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` +} + +type InputTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + CachedTokensDetails CachedTokensDetails `json:"cached_tokens_details,omitempty"` +} + +type OutputTokenDetails struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` +} + +type Usage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + // Input token details. + InputTokenDetails InputTokenDetails `json:"input_token_details,omitempty"` + // Output token details. + OutputTokenDetails OutputTokenDetails `json:"output_token_details,omitempty"` +} + +type Response struct { + // The unique ID of the response. + ID string `json:"id"` + // The object type, must be "realtime.response". + Object string `json:"object"` + // The status of the response. + Status ResponseStatus `json:"status"` + // Additional details about the status. 
+ StatusDetails any `json:"status_details,omitempty"` + // The list of output items generated by the response. + Output []ResponseMessageItem `json:"output"` + // Usage statistics for the response. + Usage *Usage `json:"usage,omitempty"` +} + +type RateLimit struct { + // The name of the rate limit ("requests", "tokens", "input_tokens", "output_tokens"). + Name string `json:"name"` + // The maximum allowed value for the rate limit. + Limit int `json:"limit"` + // The remaining value before the limit is reached. + Remaining int `json:"remaining"` + // Seconds until the rate limit resets. + ResetSeconds float64 `json:"reset_seconds"` +} + +// ClientEventType is the type of client event. See https://platform.openai.com/docs/guides/realtime/client-events +type ClientEventType string + +const ( + ClientEventTypeSessionUpdate ClientEventType = "session.update" + ClientEventTypeTranscriptionSessionUpdate ClientEventType = "transcription_session.update" + ClientEventTypeInputAudioBufferAppend ClientEventType = "input_audio_buffer.append" + ClientEventTypeInputAudioBufferCommit ClientEventType = "input_audio_buffer.commit" + ClientEventTypeInputAudioBufferClear ClientEventType = "input_audio_buffer.clear" + ClientEventTypeConversationItemCreate ClientEventType = "conversation.item.create" + ClientEventTypeConversationItemTruncate ClientEventType = "conversation.item.truncate" + ClientEventTypeConversationItemDelete ClientEventType = "conversation.item.delete" + ClientEventTypeResponseCreate ClientEventType = "response.create" + ClientEventTypeResponseCancel ClientEventType = "response.cancel" +) + +// ClientEvent is the interface for client event. +type ClientEvent interface { + ClientEventType() ClientEventType +} + +// EventBase is the base struct for all client events. +type EventBase struct { + // Optional client-generated ID used to identify this event. + EventID string `json:"event_id,omitempty"` +} + +type ClientSession struct { + Model string `json:"model,omitempty"` + // The set of modalities the model can respond with. To disable audio, set this to ["text"]. + Modalities []Modality `json:"modalities,omitempty"` + // The default system instructions prepended to model calls. + Instructions string `json:"instructions,omitempty"` + // The voice the model uses to respond - one of alloy, echo, or shimmer. Cannot be changed once the model has responded with audio at least once. + Voice string `json:"voice,omitempty"` + // The format of input audio. Options are "pcm16", "g711_ulaw", or "g711_alaw". + InputAudioFormat AudioFormat `json:"input_audio_format,omitempty"` + // The format of output audio. Options are "pcm16", "g711_ulaw", or "g711_alaw". + OutputAudioFormat AudioFormat `json:"output_audio_format,omitempty"` + // Configuration for input audio transcription. Can be set to `nil` to turn off. + InputAudioTranscription *InputAudioTranscription `json:"input_audio_transcription,omitempty"` + // Configuration for turn detection. Can be set to `nil` to turn off. + TurnDetection *ClientTurnDetection `json:"turn_detection"` + // Tools (functions) available to the model. + Tools []Tool `json:"tools,omitempty"` + // How the model chooses tools. Options are "auto", "none", "required", or specify a function. + ToolChoice ToolChoiceInterface `json:"tool_choice,omitempty"` + // Sampling temperature for the model. + Temperature *float32 `json:"temperature,omitempty"` + // Maximum number of output tokens for a single assistant response, inclusive of tool calls. 
Provide an integer between 1 and 4096 to limit output tokens, or "inf" for the maximum available tokens for a given model. Defaults to "inf". + MaxOutputTokens IntOrInf `json:"max_response_output_tokens,omitempty"` +} + +type CreateSessionRequest struct { + ClientSession + + // The Realtime model used for this session. + Model string `json:"model,omitempty"` +} + +type ClientSecret struct { + // Ephemeral key usable in client environments to authenticate connections to the Realtime API. Use this in client-side environments rather than a standard API token, which should only be used server-side. + Value string `json:"value"` + // Timestamp for when the token expires. Currently, all tokens expire after one minute. + ExpiresAt int64 `json:"expires_at"` +} + +type CreateSessionResponse struct { + ServerSession + + // Ephemeral key returned by the API. + ClientSecret ClientSecret `json:"client_secret"` +} + +// SessionUpdateEvent is the event for session update. +// Send this event to update the session’s default configuration. +// See https://platform.openai.com/docs/api-reference/realtime-client-events/session/update +type SessionUpdateEvent struct { + EventBase + // Session configuration to update. + Session ClientSession `json:"session"` +} + +func (m SessionUpdateEvent) ClientEventType() ClientEventType { + return ClientEventTypeSessionUpdate +} + +func (m SessionUpdateEvent) MarshalJSON() ([]byte, error) { + type sessionUpdateEvent SessionUpdateEvent + v := struct { + *sessionUpdateEvent + Type ClientEventType `json:"type"` + }{ + sessionUpdateEvent: (*sessionUpdateEvent)(&m), + Type: m.ClientEventType(), + } + return json.Marshal(v) +} + +// InputAudioBufferAppendEvent is the event for input audio buffer append. +// Send this event to append audio bytes to the input audio buffer. +// See https://platform.openai.com/docs/api-reference/realtime-client-events/input_audio_buffer/append +type InputAudioBufferAppendEvent struct { + EventBase + Audio string `json:"audio"` // Base64-encoded audio bytes. +} + +func (m InputAudioBufferAppendEvent) ClientEventType() ClientEventType { + return ClientEventTypeInputAudioBufferAppend +} + +func (m InputAudioBufferAppendEvent) MarshalJSON() ([]byte, error) { + type inputAudioBufferAppendEvent InputAudioBufferAppendEvent + v := struct { + *inputAudioBufferAppendEvent + Type ClientEventType `json:"type"` + }{ + inputAudioBufferAppendEvent: (*inputAudioBufferAppendEvent)(&m), + Type: m.ClientEventType(), + } + return json.Marshal(v) +} + +// InputAudioBufferCommitEvent is the event for input audio buffer commit. +// Send this event to commit audio bytes to a user message. +// See https://platform.openai.com/docs/api-reference/realtime-client-events/input_audio_buffer/commit +type InputAudioBufferCommitEvent struct { + EventBase +} + +func (m InputAudioBufferCommitEvent) ClientEventType() ClientEventType { + return ClientEventTypeInputAudioBufferCommit +} + +func (m InputAudioBufferCommitEvent) MarshalJSON() ([]byte, error) { + type inputAudioBufferCommitEvent InputAudioBufferCommitEvent + v := struct { + *inputAudioBufferCommitEvent + Type ClientEventType `json:"type"` + }{ + inputAudioBufferCommitEvent: (*inputAudioBufferCommitEvent)(&m), + Type: m.ClientEventType(), + } + return json.Marshal(v) +} + +// InputAudioBufferClearEvent is the event for input audio buffer clear. +// Send this event to clear the audio bytes in the buffer. 
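+// On the wire this event is just {"type":"input_audio_buffer.clear"}, plus an event_id if one is set.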
+// See https://platform.openai.com/docs/api-reference/realtime-client-events/input_audio_buffer/clear +type InputAudioBufferClearEvent struct { + EventBase +} + +func (m InputAudioBufferClearEvent) ClientEventType() ClientEventType { + return ClientEventTypeInputAudioBufferClear +} + +func (m InputAudioBufferClearEvent) MarshalJSON() ([]byte, error) { + type inputAudioBufferClearEvent InputAudioBufferClearEvent + v := struct { + *inputAudioBufferClearEvent + Type ClientEventType `json:"type"` + }{ + inputAudioBufferClearEvent: (*inputAudioBufferClearEvent)(&m), + Type: m.ClientEventType(), + } + return json.Marshal(v) +} + +// ConversationItemCreateEvent is the event for conversation item create. +// Send this event when adding an item to the conversation. +// See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/create +type ConversationItemCreateEvent struct { + EventBase + // The ID of the preceding item after which the new item will be inserted. + PreviousItemID string `json:"previous_item_id,omitempty"` + // The item to add to the conversation. + Item MessageItem `json:"item"` +} + +func (m ConversationItemCreateEvent) ClientEventType() ClientEventType { + return ClientEventTypeConversationItemCreate +} + +func (m ConversationItemCreateEvent) MarshalJSON() ([]byte, error) { + type conversationItemCreateEvent ConversationItemCreateEvent + v := struct { + *conversationItemCreateEvent + Type ClientEventType `json:"type"` + }{ + conversationItemCreateEvent: (*conversationItemCreateEvent)(&m), + Type: m.ClientEventType(), + } + return json.Marshal(v) +} + +// ConversationItemTruncateEvent is the event for conversation item truncate. +// Send this event when you want to truncate a previous assistant message’s audio. +// See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/truncate +type ConversationItemTruncateEvent struct { + EventBase + // The ID of the assistant message item to truncate. + ItemID string `json:"item_id"` + // The index of the content part to truncate. + ContentIndex int `json:"content_index"` + // Inclusive duration up to which audio is truncated, in milliseconds. + AudioEndMs int `json:"audio_end_ms"` +} + +func (m ConversationItemTruncateEvent) ClientEventType() ClientEventType { + return ClientEventTypeConversationItemTruncate +} + +func (m ConversationItemTruncateEvent) MarshalJSON() ([]byte, error) { + type conversationItemTruncateEvent ConversationItemTruncateEvent + v := struct { + *conversationItemTruncateEvent + Type ClientEventType `json:"type"` + }{ + conversationItemTruncateEvent: (*conversationItemTruncateEvent)(&m), + Type: m.ClientEventType(), + } + return json.Marshal(v) +} + +// ConversationItemDeleteEvent is the event for conversation item delete. +// Send this event when you want to remove any item from the conversation history. +// See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/delete +type ConversationItemDeleteEvent struct { + EventBase + // The ID of the item to delete. 
+ ItemID string `json:"item_id"` +} + +func (m ConversationItemDeleteEvent) ClientEventType() ClientEventType { + return ClientEventTypeConversationItemDelete +} + +func (m ConversationItemDeleteEvent) MarshalJSON() ([]byte, error) { + type conversationItemDeleteEvent ConversationItemDeleteEvent + v := struct { + *conversationItemDeleteEvent + Type ClientEventType `json:"type"` + }{ + conversationItemDeleteEvent: (*conversationItemDeleteEvent)(&m), + Type: m.ClientEventType(), + } + return json.Marshal(v) +} + +type ResponseCreateParams struct { + // The modalities for the response. + Modalities []Modality `json:"modalities,omitempty"` + // Instructions for the model. + Instructions string `json:"instructions,omitempty"` + // The voice the model uses to respond - one of alloy, echo, or shimmer. + Voice string `json:"voice,omitempty"` + // The format of output audio. + OutputAudioFormat AudioFormat `json:"output_audio_format,omitempty"` + // Tools (functions) available to the model. + Tools []Tool `json:"tools,omitempty"` + // How the model chooses tools. + ToolChoice ToolChoiceInterface `json:"tool_choice,omitempty"` + // Sampling temperature. + Temperature *float32 `json:"temperature,omitempty"` + // Maximum number of output tokens for a single assistant response, inclusive of tool calls. Provide an integer between 1 and 4096 to limit output tokens, or "inf" for the maximum available tokens for a given model. Defaults to "inf". + MaxOutputTokens IntOrInf `json:"max_output_tokens,omitempty"` +} + +// ResponseCreateEvent is the event for response create. +// Send this event to trigger a response generation. +// See https://platform.openai.com/docs/api-reference/realtime-client-events/response/create +type ResponseCreateEvent struct { + EventBase + // Configuration for the response. + Response ResponseCreateParams `json:"response"` +} + +func (m ResponseCreateEvent) ClientEventType() ClientEventType { + return ClientEventTypeResponseCreate +} + +func (m ResponseCreateEvent) MarshalJSON() ([]byte, error) { + type responseCreateEvent ResponseCreateEvent + v := struct { + *responseCreateEvent + Type ClientEventType `json:"type"` + }{ + responseCreateEvent: (*responseCreateEvent)(&m), + Type: m.ClientEventType(), + } + return json.Marshal(v) +} + +// ResponseCancelEvent is the event for response cancel. +// Send this event to cancel an in-progress response. +// See https://platform.openai.com/docs/api-reference/realtime-client-events/response/cancel +type ResponseCancelEvent struct { + EventBase + // A specific response ID to cancel - if not provided, will cancel an in-progress response in the default conversation. + ResponseID string `json:"response_id,omitempty"` +} + +func (m ResponseCancelEvent) ClientEventType() ClientEventType { + return ClientEventTypeResponseCancel +} + +func (m ResponseCancelEvent) MarshalJSON() ([]byte, error) { + type responseCancelEvent ResponseCancelEvent + v := struct { + *responseCancelEvent + Type ClientEventType `json:"type"` + }{ + responseCancelEvent: (*responseCancelEvent)(&m), + Type: m.ClientEventType(), + } + return json.Marshal(v) +} + +// MarshalClientEvent marshals the client event to JSON. 
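+// Each concrete event's MarshalJSON injects the matching "type" tag, so, for
+// example (illustrative):
+//
+//	b, _ := MarshalClientEvent(ResponseCreateEvent{})
+//	// b == []byte(`{"response":{},"type":"response.create"}`)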
+func MarshalClientEvent(event ClientEvent) ([]byte, error) { + return json.Marshal(event) +} + +type ServerEventType string + +const ( + ServerEventTypeError ServerEventType = "error" + ServerEventTypeSessionCreated ServerEventType = "session.created" + ServerEventTypeSessionUpdated ServerEventType = "session.updated" + ServerEventTypeTranscriptionSessionUpdated ServerEventType = "transcription_session.updated" + ServerEventTypeConversationCreated ServerEventType = "conversation.created" + ServerEventTypeInputAudioBufferCommitted ServerEventType = "input_audio_buffer.committed" + ServerEventTypeInputAudioBufferCleared ServerEventType = "input_audio_buffer.cleared" + ServerEventTypeInputAudioBufferSpeechStarted ServerEventType = "input_audio_buffer.speech_started" + ServerEventTypeInputAudioBufferSpeechStopped ServerEventType = "input_audio_buffer.speech_stopped" + ServerEventTypeConversationItemCreated ServerEventType = "conversation.item.created" + ServerEventTypeConversationItemInputAudioTranscriptionCompleted ServerEventType = "conversation.item.input_audio_transcription.completed" + ServerEventTypeConversationItemInputAudioTranscriptionFailed ServerEventType = "conversation.item.input_audio_transcription.failed" + ServerEventTypeConversationItemTruncated ServerEventType = "conversation.item.truncated" + ServerEventTypeConversationItemDeleted ServerEventType = "conversation.item.deleted" + ServerEventTypeResponseCreated ServerEventType = "response.created" + ServerEventTypeResponseDone ServerEventType = "response.done" + ServerEventTypeResponseOutputItemAdded ServerEventType = "response.output_item.added" + ServerEventTypeResponseOutputItemDone ServerEventType = "response.output_item.done" + ServerEventTypeResponseContentPartAdded ServerEventType = "response.content_part.added" + ServerEventTypeResponseContentPartDone ServerEventType = "response.content_part.done" + ServerEventTypeResponseTextDelta ServerEventType = "response.text.delta" + ServerEventTypeResponseTextDone ServerEventType = "response.text.done" + ServerEventTypeResponseAudioTranscriptDelta ServerEventType = "response.audio_transcript.delta" + ServerEventTypeResponseAudioTranscriptDone ServerEventType = "response.audio_transcript.done" + ServerEventTypeResponseAudioDelta ServerEventType = "response.audio.delta" + ServerEventTypeResponseAudioDone ServerEventType = "response.audio.done" + ServerEventTypeResponseFunctionCallArgumentsDelta ServerEventType = "response.function_call_arguments.delta" + ServerEventTypeResponseFunctionCallArgumentsDone ServerEventType = "response.function_call_arguments.done" + ServerEventTypeRateLimitsUpdated ServerEventType = "rate_limits.updated" +) + +// ServerEvent is the interface for server events. +type ServerEvent interface { + ServerEventType() ServerEventType +} + +// ServerEventBase is the base struct for all server events. +type ServerEventBase struct { + // The unique ID of the server event. + EventID string `json:"event_id,omitempty"` + // The type of the server event. + Type ServerEventType `json:"type"` +} + +func (m ServerEventBase) ServerEventType() ServerEventType { + return m.Type +} + +// ErrorEvent is the event for error. +// Returned when an error occurs. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/error +type ErrorEvent struct { + ServerEventBase + // Details of the error. + Error Error `json:"error"` +} + +// SessionCreatedEvent is the event for session created. +// Returned when a session is created. 
Emitted automatically when a new connection is established. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/session/created +type SessionCreatedEvent struct { + ServerEventBase + // The session resource. + Session ServerSession `json:"session"` +} + +// SessionUpdatedEvent is the event for session updated. +// Returned when a session is updated. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/session/updated +type SessionUpdatedEvent struct { + ServerEventBase + // The updated session resource. + Session ServerSession `json:"session"` +} + +// ConversationCreatedEvent is the event for conversation created. +// Returned when a conversation is created. Emitted right after session creation. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/created +type ConversationCreatedEvent struct { + ServerEventBase + // The conversation resource. + Conversation Conversation `json:"conversation"` +} + +// InputAudioBufferCommittedEvent is the event for input audio buffer committed. +// Returned when an input audio buffer is committed, either by the client or automatically in server VAD mode. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/committed +type InputAudioBufferCommittedEvent struct { + ServerEventBase + // The ID of the preceding item after which the new item will be inserted. + PreviousItemID string `json:"previous_item_id,omitempty"` + // The ID of the user message item that will be created. + ItemID string `json:"item_id"` +} + +// InputAudioBufferClearedEvent is the event for input audio buffer cleared. +// Returned when the input audio buffer is cleared by the client. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/cleared +type InputAudioBufferClearedEvent struct { + ServerEventBase +} + +// InputAudioBufferSpeechStartedEvent is the event for input audio buffer speech started. +// Returned in server turn detection mode when speech is detected. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/speech_started +type InputAudioBufferSpeechStartedEvent struct { + ServerEventBase + // Milliseconds since the session started when speech was detected. + AudioStartMs int64 `json:"audio_start_ms"` + // The ID of the user message item that will be created when speech stops. + ItemID string `json:"item_id"` +} + +// InputAudioBufferSpeechStoppedEvent is the event for input audio buffer speech stopped. +// Returned in server turn detection mode when speech stops. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/speech_stopped +type InputAudioBufferSpeechStoppedEvent struct { + ServerEventBase + // Milliseconds since the session started when speech stopped. + AudioEndMs int64 `json:"audio_end_ms"` + // The ID of the user message item that will be created. 
+ ItemID string `json:"item_id"` +} + +type ConversationItemCreatedEvent struct { + ServerEventBase + PreviousItemID string `json:"previous_item_id,omitempty"` + Item ResponseMessageItem `json:"item"` +} + +type ConversationItemInputAudioTranscriptionCompletedEvent struct { + ServerEventBase + ItemID string `json:"item_id"` + ContentIndex int `json:"content_index"` + Transcript string `json:"transcript"` +} + +type ConversationItemInputAudioTranscriptionFailedEvent struct { + ServerEventBase + ItemID string `json:"item_id"` + ContentIndex int `json:"content_index"` + Error Error `json:"error"` +} + +type ConversationItemTruncatedEvent struct { + ServerEventBase + ItemID string `json:"item_id"` // The ID of the assistant message item that was truncated. + ContentIndex int `json:"content_index"` // The index of the content part that was truncated. + AudioEndMs int `json:"audio_end_ms"` // The duration up to which the audio was truncated, in milliseconds. +} + +type ConversationItemDeletedEvent struct { + ServerEventBase + ItemID string `json:"item_id"` // The ID of the item that was deleted. +} + +// ResponseCreatedEvent is the event for response created. +// Returned when a new Response is created. The first event of response creation, where the response is in an initial state of "in_progress". +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/created +type ResponseCreatedEvent struct { + ServerEventBase + // The response resource. + Response Response `json:"response"` +} + +// ResponseDoneEvent is the event for response done. +// Returned when a Response is done streaming. Always emitted, no matter the final state. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/done +type ResponseDoneEvent struct { + ServerEventBase + // The response resource. + Response Response `json:"response"` +} + +// ResponseOutputItemAddedEvent is the event for response output item added. +// Returned when a new Item is created during response generation. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_item/added +type ResponseOutputItemAddedEvent struct { + ServerEventBase + // The ID of the response to which the item belongs. + ResponseID string `json:"response_id"` + // The index of the output item in the response. + OutputIndex int `json:"output_index"` + // The item that was added. + Item ResponseMessageItem `json:"item"` +} + +// ResponseOutputItemDoneEvent is the event for response output item done. +// Returned when an Item is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_item/done +type ResponseOutputItemDoneEvent struct { + ServerEventBase + // The ID of the response to which the item belongs. + ResponseID string `json:"response_id"` + // The index of the output item in the response. + OutputIndex int `json:"output_index"` + // The completed item. + Item ResponseMessageItem `json:"item"` +} + +// ResponseContentPartAddedEvent is the event for response content part added. +// Returned when a new content part is added to an assistant message item during response generation. 
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/content_part/added +type ResponseContentPartAddedEvent struct { + ServerEventBase + ResponseID string `json:"response_id"` + ItemID string `json:"item_id"` + OutputIndex int `json:"output_index"` + ContentIndex int `json:"content_index"` + Part MessageContentPart `json:"part"` +} + +// ResponseContentPartDoneEvent is the event for response content part done. +// Returned when a content part is done streaming in an assistant message item. Also emitted when a Response is interrupted, incomplete, or cancelled. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/content_part/done +type ResponseContentPartDoneEvent struct { + ServerEventBase + // The ID of the response. + ResponseID string `json:"response_id"` + // The ID of the item to which the content part was added. + ItemID string `json:"item_id"` + // The index of the output item in the response. + OutputIndex int `json:"output_index"` + // The index of the content part in the item's content array. + ContentIndex int `json:"content_index"` + // The content part that was added. + Part MessageContentPart `json:"part"` +} + +// ResponseTextDeltaEvent is the event for response text delta. +// Returned when the text value of a "text" content part is updated. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/text/delta +type ResponseTextDeltaEvent struct { + ServerEventBase + ResponseID string `json:"response_id"` + ItemID string `json:"item_id"` + OutputIndex int `json:"output_index"` + ContentIndex int `json:"content_index"` + Delta string `json:"delta"` +} + +// ResponseTextDoneEvent is the event for response text done. +// Returned when the text value of a "text" content part is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/text/done +type ResponseTextDoneEvent struct { + ServerEventBase + ResponseID string `json:"response_id"` + ItemID string `json:"item_id"` + OutputIndex int `json:"output_index"` + ContentIndex int `json:"content_index"` + Text string `json:"text"` +} + +// ResponseAudioTranscriptDeltaEvent is the event for response audio transcript delta. +// Returned when the model-generated transcription of audio output is updated. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/audio_transcript/delta +type ResponseAudioTranscriptDeltaEvent struct { + ServerEventBase + // The ID of the response. + ResponseID string `json:"response_id"` + // The ID of the item. + ItemID string `json:"item_id"` + // The index of the output item in the response. + OutputIndex int `json:"output_index"` + // The index of the content part in the item's content array. + ContentIndex int `json:"content_index"` + // The transcript delta. + Delta string `json:"delta"` +} + +// ResponseAudioTranscriptDoneEvent is the event for response audio transcript done. +// Returned when the model-generated transcription of audio output is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/audio_transcript/done +type ResponseAudioTranscriptDoneEvent struct { + ServerEventBase + // The ID of the response. + ResponseID string `json:"response_id"` + // The ID of the item. 
+ ItemID string `json:"item_id"` + // The index of the output item in the response. + OutputIndex int `json:"output_index"` + // The index of the content part in the item's content array. + ContentIndex int `json:"content_index"` + // The final transcript of the audio. + Transcript string `json:"transcript"` +} + +// ResponseAudioDeltaEvent is the event for response audio delta. +// Returned when the model-generated audio is updated. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/audio/delta +type ResponseAudioDeltaEvent struct { + ServerEventBase + // The ID of the response. + ResponseID string `json:"response_id"` + // The ID of the item. + ItemID string `json:"item_id"` + // The index of the output item in the response. + OutputIndex int `json:"output_index"` + // The index of the content part in the item's content array. + ContentIndex int `json:"content_index"` + // Base64-encoded audio data delta. + Delta string `json:"delta"` +} + +// ResponseAudioDoneEvent is the event for response audio done. +// Returned when the model-generated audio is done. Also emitted when a Response is interrupted, incomplete, or cancelled. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/audio/done +type ResponseAudioDoneEvent struct { + ServerEventBase + // The ID of the response. + ResponseID string `json:"response_id"` + // The ID of the item. + ItemID string `json:"item_id"` + // The index of the output item in the response. + OutputIndex int `json:"output_index"` + // The index of the content part in the item's content array. + ContentIndex int `json:"content_index"` +} + +// ResponseFunctionCallArgumentsDeltaEvent is the event for response function call arguments delta. +// Returned when the model-generated function call arguments are updated. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/function_call_arguments/delta +type ResponseFunctionCallArgumentsDeltaEvent struct { + ServerEventBase + // The ID of the response. + ResponseID string `json:"response_id"` + // The ID of the item. + ItemID string `json:"item_id"` + // The index of the output item in the response. + OutputIndex int `json:"output_index"` + // The ID of the function call. + CallID string `json:"call_id"` + // The arguments delta as a JSON string. + Delta string `json:"delta"` +} + +// ResponseFunctionCallArgumentsDoneEvent is the event for response function call arguments done. +// Returned when the model-generated function call arguments are done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled. +// See https://platform.openai.com/docs/api-reference/realtime-server-events/response/function_call_arguments/done +type ResponseFunctionCallArgumentsDoneEvent struct { + ServerEventBase + // The ID of the response. + ResponseID string `json:"response_id"` + // The ID of the item. + ItemID string `json:"item_id"` + // The index of the output item in the response. + OutputIndex int `json:"output_index"` + // The ID of the function call. + CallID string `json:"call_id"` + // The final arguments as a JSON string. + Arguments string `json:"arguments"` + // The name of the function. Not shown in API reference but present in the actual event. + Name string `json:"name"` +} + +// RateLimitsUpdatedEvent is the event for rate limits updated. +// Emitted after every "response.done" event to indicate the updated rate limits. 
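+// An illustrative payload, matching the RateLimit struct defined earlier:
+//
+//	{"type":"rate_limits.updated","rate_limits":[
+//	  {"name":"requests","limit":1000,"remaining":999,"reset_seconds":60}]}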
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/rate_limits/updated +type RateLimitsUpdatedEvent struct { + ServerEventBase + // List of rate limit information. + RateLimits []RateLimit `json:"rate_limits"` +} + +type ServerEventInterface interface { + ErrorEvent | + SessionCreatedEvent | + SessionUpdatedEvent | + ConversationCreatedEvent | + InputAudioBufferCommittedEvent | + InputAudioBufferClearedEvent | + InputAudioBufferSpeechStartedEvent | + InputAudioBufferSpeechStoppedEvent | + ConversationItemCreatedEvent | + ConversationItemInputAudioTranscriptionCompletedEvent | + ConversationItemInputAudioTranscriptionFailedEvent | + ConversationItemTruncatedEvent | + ConversationItemDeletedEvent | + ResponseCreatedEvent | + ResponseDoneEvent | + ResponseOutputItemAddedEvent | + ResponseOutputItemDoneEvent | + ResponseContentPartAddedEvent | + ResponseContentPartDoneEvent | + ResponseTextDeltaEvent | + ResponseTextDoneEvent | + ResponseAudioTranscriptDeltaEvent | + ResponseAudioTranscriptDoneEvent | + ResponseAudioDeltaEvent | + ResponseAudioDoneEvent | + ResponseFunctionCallArgumentsDeltaEvent | + ResponseFunctionCallArgumentsDoneEvent | + RateLimitsUpdatedEvent +} + +func unmarshalServerEvent[T ServerEventInterface](data []byte) (T, error) { + var t T + err := json.Unmarshal(data, &t) + if err != nil { + return t, err + } + return t, nil +} + +// UnmarshalServerEvent unmarshals the server event from the given JSON data. +func UnmarshalServerEvent(data []byte) (ServerEvent, error) { //nolint:funlen,cyclop // TODO: optimize + var eventType struct { + Type ServerEventType `json:"type"` + } + err := json.Unmarshal(data, &eventType) + if err != nil { + return nil, err + } + switch eventType.Type { + case ServerEventTypeError: + return unmarshalServerEvent[ErrorEvent](data) + case ServerEventTypeSessionCreated: + return unmarshalServerEvent[SessionCreatedEvent](data) + case ServerEventTypeSessionUpdated: + return unmarshalServerEvent[SessionUpdatedEvent](data) + case ServerEventTypeConversationCreated: + return unmarshalServerEvent[ConversationCreatedEvent](data) + case ServerEventTypeInputAudioBufferCommitted: + return unmarshalServerEvent[InputAudioBufferCommittedEvent](data) + case ServerEventTypeInputAudioBufferCleared: + return unmarshalServerEvent[InputAudioBufferClearedEvent](data) + case ServerEventTypeInputAudioBufferSpeechStarted: + return unmarshalServerEvent[InputAudioBufferSpeechStartedEvent](data) + case ServerEventTypeInputAudioBufferSpeechStopped: + return unmarshalServerEvent[InputAudioBufferSpeechStoppedEvent](data) + case ServerEventTypeConversationItemCreated: + return unmarshalServerEvent[ConversationItemCreatedEvent](data) + case ServerEventTypeConversationItemInputAudioTranscriptionCompleted: + return unmarshalServerEvent[ConversationItemInputAudioTranscriptionCompletedEvent](data) + case ServerEventTypeConversationItemInputAudioTranscriptionFailed: + return unmarshalServerEvent[ConversationItemInputAudioTranscriptionFailedEvent](data) + case ServerEventTypeConversationItemTruncated: + return unmarshalServerEvent[ConversationItemTruncatedEvent](data) + case ServerEventTypeConversationItemDeleted: + return unmarshalServerEvent[ConversationItemDeletedEvent](data) + case ServerEventTypeResponseCreated: + return unmarshalServerEvent[ResponseCreatedEvent](data) + case ServerEventTypeResponseDone: + return unmarshalServerEvent[ResponseDoneEvent](data) + case ServerEventTypeResponseOutputItemAdded: + return 
unmarshalServerEvent[ResponseOutputItemAddedEvent](data) + case ServerEventTypeResponseOutputItemDone: + return unmarshalServerEvent[ResponseOutputItemDoneEvent](data) + case ServerEventTypeResponseContentPartAdded: + return unmarshalServerEvent[ResponseContentPartAddedEvent](data) + case ServerEventTypeResponseContentPartDone: + return unmarshalServerEvent[ResponseContentPartDoneEvent](data) + case ServerEventTypeResponseTextDelta: + return unmarshalServerEvent[ResponseTextDeltaEvent](data) + case ServerEventTypeResponseTextDone: + return unmarshalServerEvent[ResponseTextDoneEvent](data) + case ServerEventTypeResponseAudioTranscriptDelta: + return unmarshalServerEvent[ResponseAudioTranscriptDeltaEvent](data) + case ServerEventTypeResponseAudioTranscriptDone: + return unmarshalServerEvent[ResponseAudioTranscriptDoneEvent](data) + case ServerEventTypeResponseAudioDelta: + return unmarshalServerEvent[ResponseAudioDeltaEvent](data) + case ServerEventTypeResponseAudioDone: + return unmarshalServerEvent[ResponseAudioDoneEvent](data) + case ServerEventTypeResponseFunctionCallArgumentsDelta: + return unmarshalServerEvent[ResponseFunctionCallArgumentsDeltaEvent](data) + case ServerEventTypeResponseFunctionCallArgumentsDone: + return unmarshalServerEvent[ResponseFunctionCallArgumentsDoneEvent](data) + case ServerEventTypeRateLimitsUpdated: + return unmarshalServerEvent[RateLimitsUpdatedEvent](data) + default: + // This should never happen. + return nil, fmt.Errorf("unknown server event type: %s", eventType.Type) + } +} diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index fd17613a..b3f1af59 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -15,6 +15,12 @@ func RegisterOpenAIRoutes(app *fiber.App, application *application.Application) { // openAI compatible API endpoint + // realtime + // TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions + app.Get("/v1/realtime", openai.Realtime(application)) + app.Post("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application)) + app.Post("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application)) + // chat chatChain := []fiber.Handler{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), diff --git a/go.mod b/go.mod index 5567c372..4d1079cb 100644 --- a/go.mod +++ b/go.mod @@ -9,10 +9,8 @@ require ( github.com/GeertJohan/go.rice v1.0.3 github.com/Masterminds/sprig/v3 v3.3.0 github.com/alecthomas/kong v0.9.0 - github.com/census-instrumentation/opencensus-proto v0.4.1 github.com/charmbracelet/glamour v0.7.0 github.com/chasefleming/elem-go v0.26.0 - github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20 github.com/containerd/containerd v1.7.19 github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 github.com/elliotchance/orderedmap/v2 v2.2.0 @@ -25,11 +23,9 @@ require ( github.com/gofiber/template/html/v2 v2.1.2 github.com/gofiber/websocket/v2 v2.2.1 github.com/gofrs/flock v0.12.1 - github.com/golang/protobuf v1.5.4 github.com/google/go-containerregistry v0.19.2 github.com/google/uuid v1.6.0 github.com/gpustack/gguf-parser-go v0.17.0 - github.com/grpc-ecosystem/grpc-gateway v1.5.0 github.com/hpcloud/tail v1.0.0 github.com/ipfs/go-log v1.0.5 github.com/jaypipes/ghw v0.12.0 @@ -43,7 +39,6 @@ require ( github.com/nikolalohinski/gonja/v2 v2.3.2 github.com/onsi/ginkgo/v2 v2.22.2 github.com/onsi/gomega v1.36.2 - github.com/orcaman/writerseeker 
v0.0.0-20200621085525-1d3f536ff85e github.com/otiai10/openaigo v1.7.0 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/prometheus/client_golang v1.20.5 @@ -62,7 +57,6 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.50.0 go.opentelemetry.io/otel/metric v1.34.0 go.opentelemetry.io/otel/sdk/metric v1.28.0 - google.golang.org/api v0.180.0 google.golang.org/grpc v1.67.1 google.golang.org/protobuf v1.36.5 gopkg.in/yaml.v2 v2.4.0 @@ -71,22 +65,14 @@ require ( ) require ( - cel.dev/expr v0.16.0 // indirect - cloud.google.com/go/auth v0.4.1 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.5.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/daaku/go.zipexe v1.0.2 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect - github.com/fasthttp/websocket v1.5.3 // indirect + github.com/fasthttp/websocket v1.5.8 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect - github.com/google/s2a-go v0.1.7 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect - github.com/googleapis/gax-go/v2 v2.12.4 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect @@ -119,15 +105,13 @@ require ( github.com/pion/turn/v4 v4.0.0 // indirect github.com/pion/webrtc/v4 v4.0.9 // indirect github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 // indirect - github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect + github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511 // indirect github.com/shirou/gopsutil/v4 v4.24.7 // indirect github.com/wlynxg/anet v0.0.5 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect go.uber.org/mock v0.5.0 // indirect - golang.org/x/oauth2 v0.24.0 // indirect golang.org/x/time v0.8.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect ) require ( @@ -162,7 +146,7 @@ require ( github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/docker v27.1.1+incompatible github.com/docker/docker-credential-helpers v0.7.0 // indirect - github.com/docker/go-connections v0.5.0 // indirect + github.com/docker/go-connections v0.5.0 github.com/docker/go-units v0.5.0 // indirect github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect github.com/elastic/gosigar v0.14.3 // indirect diff --git a/go.sum b/go.sum index 6af7a14b..dae92fae 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,7 @@ -cel.dev/expr v0.16.0 h1:yloc84fytn4zmJX2GU3TkXGsaieaV7dQ057Qs4sIG2Y= -cel.dev/expr v0.16.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= -cloud.google.com/go/auth v0.4.1 h1:Z7YNIhlWRtrnKlZke7z3GMqzvuYzdc2z98F9D1NV5Hg= -cloud.google.com/go/auth v0.4.1/go.mod h1:QVBuVEKpCn4Zp58hzRGvL0tjRGU0YqdRTdCHM1IHnro= 
-cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= -cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= -cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= -cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= @@ -73,8 +65,6 @@ github.com/c-robinson/iplib v1.0.8/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szN github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= -github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charmbracelet/glamour v0.7.0 h1:2BtKGZ4iVJCDfMF229EzbeR1QRKLWztO9dMtjmqZSng= @@ -84,8 +74,6 @@ github.com/chasefleming/elem-go v0.26.0/go.mod h1:hz73qILBIKnTgOujnSMtEj20/epI+f github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20 h1:N+3sFI5GUjRKBi+i0TxYVST9h4Ie192jJWpHvthBBgg= -github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE= github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM= github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw= @@ -161,10 +149,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/envoyproxy/protoc-gen-validate v1.1.0 h1:tntQDh69XqOCOZsDz0lVJQez/2L6Uu2PdjCQwWCJ3bM= -github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4= -github.com/fasthttp/websocket v1.5.3 h1:TPpQuLwJYfd4LJPXvHDYPMFWbLjsT91n3GpWtCQtdek= -github.com/fasthttp/websocket v1.5.3/go.mod h1:46gg/UBmTU1kUaTcwQXpUxtRwG2PvIZYeA8oL6vF3Fs= +github.com/fasthttp/websocket v1.5.8 h1:k5DpirKkftIF/w1R8ZzjSgARJrs54Je9YJK37DL/Ah8= +github.com/fasthttp/websocket v1.5.8/go.mod h1:d08g8WaT6nnyvg9uMm8K9zMYyDjfKyj3170AtPRuVU0= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop 
v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= @@ -177,6 +163,8 @@ github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20240626202019-c118733a29ad h1:dQ93Vd6i25o+zH9vvnZ8mu7jtJQ6jT3D+zE3V8Q49n0= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20240626202019-c118733a29ad/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= @@ -250,8 +238,6 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -280,18 +266,12 @@ github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OI github.com/google/pprof v0.0.0-20250208200701-d0013a598941 h1:43XjGa6toxLpeksjcxs1jIoIyr+vUfOqY2c6HB4bpoc= github.com/google/pprof v0.0.0-20250208200701-d0013a598941/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= -github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= -github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= -github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg= -github.com/googleapis/gax-go/v2 v2.12.4/go.mod h1:KYEYLorsnIGDi/rPC8b5TdlB9kbKoFubselGIoBMCwI= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs 
v0.0.0-20190430165422-3e4dfb77656c h1:7lF+Vz0LqiRidnzC1Oq86fpX1q/iEv2KJdrCtttYjT4= github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -505,6 +485,8 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/mudler/edgevpn v0.30.1 h1:4yyhNFJX62NpRp50sxiyZE5E/sdAqEZX+aE5Mv7QS60= github.com/mudler/edgevpn v0.30.1/go.mod h1:IAJkkJ0oH3rwsSGOGTFT4UBYFqYuD/QyaKzTLB3P/eU= +github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc h1:RxwneJl1VgvikiX28EkpdAyL4yQVnJMrbquKospjHyA= +github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82 h1:FVT07EI8njvsD4tC2Hw8Xhactp5AWhsQWD4oTeQuSAU= github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82/go.mod h1:Urp7LG5jylKoDq0663qeBh0pINGcRl35nXdKx82PSoU= github.com/mudler/water v0.0.0-20221010214108-8c7313014ce0 h1:Qh6ghkMgTu6siFbTf7L3IszJmshMhXxNL4V+t7IIA6w= @@ -564,8 +546,6 @@ github.com/opencontainers/runtime-spec v1.2.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/ github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= -github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw= -github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0= github.com/otiai10/mint v1.6.1 h1:kgbTJmOpp/0ce7hk3H8jiSuR0MXmpwWRfqUdKww17qg= github.com/otiai10/mint v1.6.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= github.com/otiai10/openaigo v1.7.0 h1:AOQcOjRRM57ABvz+aI2oJA/Qsz1AydKbdZAlGiKyCqg= @@ -681,8 +661,8 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sashabaranov/go-openai v1.26.2 h1:cVlQa3gn3eYqNXRW03pPlpy6zLG52EU4g0FrWXc0EFI= github.com/sashabaranov/go-openai v1.26.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= -github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= -github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= +github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511 h1:KanIMPX0QdEdB4R3CiimCAbxFrhB3j7h0/OvpYGVQa8= +github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= github.com/schollz/progressbar/v3 v3.14.4 h1:W9ZrDSJk7eqmQhd3uxFNNcTr0QL+xuGNI9dEMrw0r74= github.com/schollz/progressbar/v3 v3.14.4/go.mod h1:aT3UQ7yGm+2ZjeXPqsjTenwL3ddUiuZ0kfQ/2tHlyNI= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= @@ -925,8 +905,6 @@ golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAG golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod 
h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= -golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1046,8 +1024,6 @@ gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= -google.golang.org/api v0.180.0 h1:M2D87Yo0rGBPWpo1orwfCLehUUL6E7/TYe5gvMQWDh4= -google.golang.org/api v0.180.0/go.mod h1:51AiyoEg1MJPSZ9zvklA8VnRILPXxn1iVen9v25XHAE= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1060,7 +1036,6 @@ google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRn google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda h1:wu/KJm9KJwpfHWhkkZGohVC6KRrc1oJNr4jwtQMOQXw= -google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda/go.mod h1:g2LLCvCeCSir/JJSWosk19BR4NVxGqHUC6rxIRsd7Aw= google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 h1:T6rh4haD3GVYsgEfWExoCZA2o2FmbNyKpTuAxbEFPTg= google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:wp2WsuBYj6j8wUdo3ToZsdxxixbvQNAHqVJrTgi5E5M= google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 h1:QCqS/PdaHTSWGvupk2F/ehwHtGc0/GYkT+3GAcR1CCc= diff --git a/pkg/audio/audio.go b/pkg/audio/audio.go new file mode 100644 index 00000000..946d902f --- /dev/null +++ b/pkg/audio/audio.go @@ -0,0 +1,55 @@ +package audio + +// Copied from VoxInput + +import ( + "encoding/binary" + "io" +) + +// WAVHeader represents the WAV file header (44 bytes for PCM) +type WAVHeader struct { + // RIFF Chunk (12 bytes) + ChunkID [4]byte + ChunkSize uint32 + Format [4]byte + + // fmt Subchunk (16 bytes) + Subchunk1ID [4]byte + Subchunk1Size uint32 + AudioFormat uint16 + NumChannels uint16 + SampleRate uint32 + ByteRate uint32 + BlockAlign uint16 + BitsPerSample uint16 + + // data Subchunk (8 bytes) + Subchunk2ID [4]byte + Subchunk2Size uint32 +} + +func NewWAVHeader(pcmLen uint32) WAVHeader { + header := WAVHeader{ + ChunkID: [4]byte{'R', 'I', 'F', 'F'}, + Format: [4]byte{'W', 'A', 'V', 'E'}, + Subchunk1ID: [4]byte{'f', 'm', 't', ' '}, + Subchunk1Size: 16, // PCM = 16 bytes + AudioFormat: 1, // PCM + NumChannels: 1, // Mono + SampleRate: 16000, + ByteRate: 16000 * 2, // SampleRate * BlockAlign (mono, 2 bytes per sample) + BlockAlign: 2, // 16-bit = 2 bytes per sample + BitsPerSample: 16, + Subchunk2ID: [4]byte{'d', 'a', 't', 'a'}, + Subchunk2Size: pcmLen, + } + + 
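+	// ChunkSize is the RIFF payload size: the 36 header bytes that follow this
+	// field plus the PCM data (i.e. the 44-byte header minus the 8-byte
+	// ChunkID/ChunkSize preamble).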
+	header.ChunkSize = 36 + header.Subchunk2Size
+
+	return header
+}
+
+func (h *WAVHeader) Write(writer io.Writer) error {
+	return binary.Write(writer, binary.LittleEndian, h)
+}
diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go
index 9f9f19b1..ac3fe757 100644
--- a/pkg/grpc/backend.go
+++ b/pkg/grpc/backend.go
@@ -35,9 +35,9 @@ type Backend interface {
 	IsBusy() bool
 	HealthCheck(ctx context.Context) (bool, error)
 	Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
-	Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
 	LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
 	PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
+	Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
 	GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
diff --git a/pkg/sound/float32.go b/pkg/sound/float32.go
new file mode 100644
index 00000000..f42a04e5
--- /dev/null
+++ b/pkg/sound/float32.go
@@ -0,0 +1,14 @@
+package sound
+
+import (
+	"encoding/binary"
+	"math"
+)
+
+// BytesFloat32 decodes the first four bytes of a little-endian buffer as an
+// IEEE 754 float32 sample.
+func BytesFloat32(bytes []byte) float32 {
+	bits := binary.LittleEndian.Uint32(bytes)
+	float := math.Float32frombits(bits)
+	return float
+}
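For context, a minimal usage sketch of the new pkg/audio header type above (not part of the diff; the file name and the silent payload are arbitrary illustration):

package main

import (
	"log"
	"os"

	"github.com/mudler/LocalAI/pkg/audio"
)

func main() {
	// One second of 16 kHz mono 16-bit silence: 16000 samples * 2 bytes each.
	pcm := make([]byte, 16000*2)

	f, err := os.Create("silence.wav")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// NewWAVHeader fixes the format at 16 kHz/mono/16-bit PCM, so callers
	// only supply the payload length; Write emits the 44-byte header.
	h := audio.NewWAVHeader(uint32(len(pcm)))
	if err := h.Write(f); err != nil {
		log.Fatal(err)
	}
	if _, err := f.Write(pcm); err != nil {
		log.Fatal(err)
	}
}

Hardcoding the format keeps the header a plain fixed-size struct that binary.Write can serialize in one call.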
diff --git a/pkg/sound/int16.go b/pkg/sound/int16.go
new file mode 100644
index 00000000..f56aa14f
--- /dev/null
+++ b/pkg/sound/int16.go
@@ -0,0 +1,96 @@
+package sound
+
+import (
+	"encoding/binary"
+	"math"
+)
+
+/*
+
+MIT License
+
+Copyright (c) 2024 Xbozon
+
+*/
+
+// CalculateRMS16 calculates the root mean square of the audio buffer for int16 samples.
+func CalculateRMS16(buffer []int16) float64 {
+	if len(buffer) == 0 {
+		return 0 // Avoid dividing by zero (NaN) on an empty buffer
+	}
+	var sumSquares float64
+	for _, sample := range buffer {
+		val := float64(sample) // Convert int16 to float64 for calculation
+		sumSquares += val * val
+	}
+	meanSquares := sumSquares / float64(len(buffer))
+	return math.Sqrt(meanSquares)
+}
+
+func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
+	// Calculate the resampling ratio
+	ratio := float64(inputRate) / float64(outputRate)
+
+	// Calculate the length of the resampled output
+	outputLength := int(float64(len(input)) / ratio)
+	if outputLength < 1 {
+		return nil // Not enough input to produce a single output sample
+	}
+
+	// Allocate a slice for the resampled output
+	output := make([]int16, outputLength)
+
+	// Perform linear interpolation for resampling
+	for i := 0; i < outputLength-1; i++ {
+		// Calculate the corresponding position in the input
+		pos := float64(i) * ratio
+
+		// Calculate the indices of the surrounding input samples
+		indexBefore := int(pos)
+		indexAfter := indexBefore + 1
+		if indexAfter >= len(input) {
+			indexAfter = len(input) - 1
+		}
+
+		// Calculate the fractional part of the position
+		frac := pos - float64(indexBefore)
+
+		// Linearly interpolate between the two surrounding input samples
+		output[i] = int16((1-frac)*float64(input[indexBefore]) + frac*float64(input[indexAfter]))
+	}
+
+	// Handle the last sample explicitly to avoid index out of range
+	output[outputLength-1] = input[len(input)-1]
+
+	return output
+}
+
+func ConvertInt16ToInt(input []int16) []int {
+	output := make([]int, len(input)) // Allocate a slice for the output
+	for i, value := range input {
+		output[i] = int(value) // Convert each int16 to int and assign it to the output slice
+	}
+	return output // Return the converted slice
+}
+
+func BytesToInt16sLE(bytes []byte) []int16 {
+	// Ensure the byte slice length is even
+	if len(bytes)%2 != 0 {
+		panic("bytesToInt16sLE: input bytes slice has odd length, must be even")
+	}
+
+	int16s := make([]int16, len(bytes)/2)
+	for i := 0; i < len(int16s); i++ {
+		int16s[i] = int16(bytes[2*i]) | int16(bytes[2*i+1])<<8
+	}
+	return int16s
+}
+
+func Int16toBytesLE(arr []int16) []byte {
+	le := binary.LittleEndian
+	result := make([]byte, 0, 2*len(arr))
+	for _, val := range arr {
+		result = le.AppendUint16(result, uint16(val))
+	}
+	return result
+}
diff --git a/pkg/utils/ffmpeg.go b/pkg/utils/ffmpeg.go
index 68683370..03384fa2 100644
--- a/pkg/utils/ffmpeg.go
+++ b/pkg/utils/ffmpeg.go
@@ -5,6 +5,8 @@ import (
 	"os"
 	"os/exec"
 	"strings"
+
+	"github.com/go-audio/wav"
 )
 
 func ffmpegCommand(args []string) (string, error) {
@@ -17,6 +19,24 @@ func ffmpegCommand(args []string) (string, error) {
 // AudioToWav converts audio to wav for transcribe.
 // TODO: use https://github.com/mccoyst/ogg?
 func AudioToWav(src, dst string) error {
+	if strings.HasSuffix(src, ".wav") {
+		f, err := os.Open(src)
+		if err != nil {
+			return fmt.Errorf("open: %w", err)
+		}
+
+		dec := wav.NewDecoder(f)
+		dec.ReadInfo()
+		f.Close()
+
+		if dec.BitDepth == 16 && dec.NumChans == 1 && dec.SampleRate == 16000 {
+			// Already 16-bit mono 16 kHz PCM: skip the ffmpeg round-trip and move the file into place.
+			if err := os.Rename(src, dst); err != nil {
+				return fmt.Errorf("rename: %w", err)
+			}
+			return nil
+		}
+	}
 	commandArgs := []string{"-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst}
 	out, err := ffmpegCommand(commandArgs)
 	if err != nil {
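Taken together, the pkg/sound helpers above cover the realtime audio path: decode a little-endian PCM byte stream, gate on signal energy, and resample between the model's rate and the 16 kHz the rest of the pipeline expects. A minimal sketch of that flow (processChunk, the 24 kHz input rate, and the RMS threshold of 1000 are illustrative assumptions, not part of this change):

package main

import (
	"fmt"

	"github.com/mudler/LocalAI/pkg/sound"
)

// processChunk decodes raw little-endian PCM bytes, drops near-silent
// chunks, and downsamples 24 kHz audio to 16 kHz.
func processChunk(raw []byte) []byte {
	samples := sound.BytesToInt16sLE(raw)

	// Energy gate: 1000 is an arbitrary threshold chosen for illustration.
	if sound.CalculateRMS16(samples) < 1000 {
		return nil
	}

	out := sound.ResampleInt16(samples, 24000, 16000)
	return sound.Int16toBytesLE(out)
}

func main() {
	// A 10 ms constant-amplitude chunk at 24 kHz: 240 samples.
	chunk := make([]int16, 240)
	for i := range chunk {
		chunk[i] = 16000
	}
	// 240 samples resampled by 24000/16000 -> 160 samples -> 320 bytes.
	fmt.Println(len(processChunk(sound.Int16toBytesLE(chunk))))
}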