From 9a0982066fe0f98407d067a4f91dda7ed2b9c02b Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 9 Jan 2025 22:07:57 +0100
Subject: [PATCH] WIP - improve start and end of speech detection

Signed-off-by: Ettore Di Giacinto
---
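Notes:

The new loop polls the VAD backend every 300ms over the whole input
buffer: a growing segment count is treated as "speech started", and the
utterance is finalized once no new segment has appeared for more than a
second. SpeechFramesThreshold and SilenceFramesThreshold are declared
for a consecutive-frame debounce that is not wired up yet; below is a
minimal sketch of how such a debounce could look (vadDebouncer and push
are hypothetical names, not part of this patch, assuming the code lives
in the same package as the constants):

	// vadDebouncer is a hypothetical helper (not in this patch) that
	// smooths per-tick VAD results so brief blips of noise or silence
	// do not toggle the speaking state.
	type vadDebouncer struct {
		speechFrames  int
		silenceFrames int
		speaking      bool
	}

	// push feeds one per-tick result (true = speech detected) and
	// reports whether an utterance just started or just ended.
	func (d *vadDebouncer) push(speech bool) (started, ended bool) {
		if speech {
			d.speechFrames++
			d.silenceFrames = 0
			if !d.speaking && d.speechFrames >= SpeechFramesThreshold {
				d.speaking = true
				started = true
			}
		} else {
			d.silenceFrames++
			d.speechFrames = 0
			if d.speaking && d.silenceFrames >= SilenceFramesThreshold {
				d.speaking = false
				ended = true
			}
		}
		return started, ended
	}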
 core/http/endpoints/openai/realtime.go | 267 +++++++++++++------------
 1 file changed, 137 insertions(+), 130 deletions(-)

diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go
index 19ae0afe..4adc60c1 100644
--- a/core/http/endpoints/openai/realtime.go
+++ b/core/http/endpoints/openai/realtime.go
@@ -497,158 +497,165 @@ type VADState int
 
 const (
 	StateSilence VADState = iota
 	StateSpeaking
-	StateTrailingSilence
 )
 
-// handle VAD (Voice Activity Detection)
-func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) {
+const (
+	// tune these thresholds to taste (frame-count debouncing; not yet wired up below)
+	SpeechFramesThreshold  = 3 // must see X consecutive speech results to confirm "start"
+	SilenceFramesThreshold = 5 // must see X consecutive silence results to confirm "end"
+)
+
+// handleVAD is run as a goroutine: it polls the audio data received from
+// the client, runs VAD on it, and commits detected utterances to the conversation
+func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
 	vadContext, cancel := context.WithCancel(context.Background())
-	//var startListening time.Time
-
 	go func() {
 		<-done
 		cancel()
 	}()
 
-	vadState := VADState(StateSilence)
-	segments := []*proto.VADSegment{}
-	timeListening := time.Now()
+	ticker := time.NewTicker(300 * time.Millisecond)
+	defer ticker.Stop()
+
+	var (
+		lastSegmentCount int
+		timeOfLastNewSeg time.Time
+		speaking         bool
+	)
 
-	// Implement VAD logic here
-	// For brevity, this is a placeholder
-	// When VAD detects end of speech, generate a response
-	// TODO: use session.ModelInterface to handle VAD and cut audio and detect when to process that
 	for {
 		select {
 		case <-done:
 			return
-		default:
-			// Check if there's audio data to process
+		case <-ticker.C:
+			// 1) Copy the entire buffer
 			session.AudioBufferLock.Lock()
+			allAudio := make([]byte, len(session.InputAudioBuffer))
+			copy(allAudio, session.InputAudioBuffer)
+			session.AudioBufferLock.Unlock()
 
-			if len(session.InputAudioBuffer) > 0 {
-
-				if vadState == StateTrailingSilence {
-					log.Debug().Msgf("VAD detected speech that we can process")
-
-					// Commit the audio buffer as a conversation item
-					item := &Item{
-						ID:     generateItemID(),
-						Object: "realtime.item",
-						Type:   "message",
-						Status: "completed",
-						Role:   "user",
-						Content: []ConversationContent{
-							{
-								Type:  "input_audio",
-								Audio: base64.StdEncoding.EncodeToString(session.InputAudioBuffer),
-							},
-						},
-					}
-
-					// Add item to conversation
-					conversation.Lock.Lock()
-					conversation.Items = append(conversation.Items, item)
-					conversation.Lock.Unlock()
-
-					// Reset InputAudioBuffer
-					session.InputAudioBuffer = nil
-					session.AudioBufferLock.Unlock()
-
-					// Send item.created event
-					sendEvent(c, OutgoingMessage{
-						Type: "conversation.item.created",
-						Item: item,
-					})
-
-					vadState = StateSilence
-					segments = []*proto.VADSegment{}
-					// Generate a response
-					generateResponse(cfg, evaluator, session, conversation, ResponseCreate{}, c, websocket.TextMessage)
-					continue
-				}
-
-				adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
-
-				// Resample from 24kHz to 16kHz
-				// adata = sound.ResampleInt16(adata, 24000, 16000)
-
-				soundIntBuffer := &audio.IntBuffer{
-					Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
-				}
-				soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
-
-				/* if len(adata) < 16000 {
-					log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
-					session.AudioBufferLock.Unlock()
-					continue
-				} */
-				float32Data := soundIntBuffer.AsFloat32Buffer().Data
-
-				// TODO: testing wav decoding
-				// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
-				// buf, err := dec.FullPCMBuffer()
-				// if err != nil {
-				//	//log.Error().Msgf("failed to process audio: %s", err.Error())
-				//	sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
-				//	session.AudioBufferLock.Unlock()
-				//	continue
-				// }
-
-				//float32Data = buf.AsFloat32Buffer().Data
-
-				resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
-					Audio: float32Data,
-				})
-				if err != nil {
-					log.Error().Msgf("failed to process audio: %s", err.Error())
-					sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
-					session.AudioBufferLock.Unlock()
-					continue
-				}
-
-				if len(resp.Segments) == 0 {
-					log.Debug().Msg("VAD detected no speech activity")
-					log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
-					if len(session.InputAudioBuffer) > 16000 {
-						session.InputAudioBuffer = nil
-						segments = []*proto.VADSegment{}
-					}
-
-					log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
-				} else if (len(resp.Segments) != len(segments)) && vadState == StateSpeaking {
-					// We have new segments, but we are still speaking
-					// We need to wait for the trailing silence
-
-					segments = resp.Segments
-
-				} else if (len(resp.Segments) == len(segments)) && vadState == StateSpeaking {
-					// We have the same number of segments, but we are still speaking
-					// We need to check if we are in this state for long enough, update the timer
-
-					// Check if we have been listening for too long
-					if time.Since(timeListening) > sendToVADDelay {
-						vadState = StateTrailingSilence
-					} else {
-
-						timeListening = timeListening.Add(time.Since(timeListening))
-					}
-				} else {
-					log.Debug().Msg("VAD detected speech activity")
-					vadState = StateSpeaking
-					segments = resp.Segments
-				}
-
-				session.AudioBufferLock.Unlock()
-			} else {
-				session.AudioBufferLock.Unlock()
+			// 2) If there's no audio at all, just continue
+			if len(allAudio) == 0 {
+				continue
 			}
+
+			// 3) Run VAD on the entire audio so far
+			segments, err := runVAD(vadContext, session, allAudio)
+			if err != nil {
+				log.Error().Msgf("failed to process audio: %s", err.Error())
+				sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+				// error already reported to the client; skip this tick
+				continue
+			}
+
+			segCount := len(segments)
+
+			if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
+				// no speech detected, and we haven't seen a new segment in > 1s
+				// clean up input
+				session.AudioBufferLock.Lock()
+				session.InputAudioBuffer = nil
+				session.AudioBufferLock.Unlock()
+				log.Debug().Msgf("Detected silence for a while, clearing audio buffer")
+				continue
+			}
+
+			// 4) If we see more segments than before => "new speech"
+			if segCount > lastSegmentCount {
+				speaking = true
+				lastSegmentCount = segCount
+				timeOfLastNewSeg = time.Now()
+				log.Debug().Msgf("Detected new speech segment")
+			}
+
+			// 5) If speaking, but we haven't seen a new segment in > 1s => finalize
+			if speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
+				log.Debug().Msgf("Detected end of speech segment")
+				// user has presumably stopped talking
+				commitUtterance(allAudio, cfg, evaluator, session, conv, c)
+				// reset state
+				speaking = false
+				lastSegmentCount = 0
+			}
 		}
 	}
 }
 
+func commitUtterance(utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
+	if len(utt) == 0 {
+		return
+	}
+	// Commit logic: create item, broadcast item.created, etc.
+	item := &Item{
+		ID:     generateItemID(),
+		Object: "realtime.item",
+		Type:   "message",
+		Status: "completed",
+		Role:   "user",
+		Content: []ConversationContent{
+			{
+				Type:  "input_audio",
+				Audio: base64.StdEncoding.EncodeToString(utt),
+			},
+		},
+	}
+	conv.Lock.Lock()
+	conv.Items = append(conv.Items, item)
+	conv.Lock.Unlock()
+
+	sendEvent(c, OutgoingMessage{
+		Type: "conversation.item.created",
+		Item: item,
+	})
+
+	// Optionally trigger the response generation
+	generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage)
+}
+
+// runVAD is a helper that calls the session's VAD backend on the given
+// audio chunk, returning the detected speech segments (an empty slice means silence)
+func runVAD(ctx context.Context, session *Session, chunk []byte) ([]*proto.VADSegment, error) {
+	adata := sound.BytesToInt16sLE(chunk)
+
+	// Resample from 24kHz to 16kHz
+	// adata = sound.ResampleInt16(adata, 24000, 16000)
+
+	soundIntBuffer := &audio.IntBuffer{
+		Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
+	}
+	soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
+
+	/* TODO: skip very short buffers before calling VAD, e.g.:
+	if len(adata) < 16000 {
+		log.Debug().Msgf("audio length too small %d", len(chunk))
+		return nil, nil
+	} */
+	float32Data := soundIntBuffer.AsFloat32Buffer().Data
+
+	resp, err := session.ModelInterface.VAD(ctx, &proto.VADRequest{
+		Audio: float32Data,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO: testing wav decoding
+	// dec := wav.NewDecoder(bytes.NewReader(chunk))
+	// buf, err := dec.FullPCMBuffer()
+	// if err != nil {
+	//	return nil, err
+	// }
+	//
+	// float32Data = buf.AsFloat32Buffer().Data
+
+	// If resp.Segments is empty => no speech
+	return resp.Segments, nil
+}
+
 // Function to generate a response based on the conversation
 func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {