Improve audio detection

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Ettore Di Giacinto 2025-01-09 19:37:18 +01:00
parent 01aace3017
commit 30e3c47598


@@ -13,7 +13,6 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/gofiber/websocket/v2"
 	"github.com/mudler/LocalAI/core/application"
-	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/functions"
@@ -138,6 +137,8 @@ func registerRealtime(application *application.Application) func(c *websocket.Co
 		model = "gpt-4o"
 	}
 
+	log.Info().Msgf("New session with model: %s", model)
+
 	sessionID := generateSessionID()
 	session := &Session{
 		ID: sessionID,
@@ -487,9 +488,16 @@ func updateSession(session *Session, update *Session, cl *config.BackendConfigLo
 }
 
 const (
 	minMicVolume   = 450
 	sendToVADDelay = time.Second
-	maxWhisperSegmentDuration = time.Second * 15
 )
 
+type VADState int
+
+const (
+	StateSilence VADState = iota
+	StateSpeaking
+	StateTrailingSilence
+)
+
 // handle VAD (Voice Activity Detection)
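The three states replace the previous single `audioDetected` flag. A minimal sketch of how the new logic treats them, condensed from the handleVAD changes below (the `transition` helper is illustrative, not part of the commit):

package main

import (
    "fmt"
    "time"
)

type VADState int

const (
    StateSilence VADState = iota
    StateSpeaking
    StateTrailingSilence
)

const sendToVADDelay = time.Second

// transition is an illustrative condensation of the handleVAD logic in this
// commit: a growing segment count keeps the session speaking, an unchanged
// count for longer than sendToVADDelay means trailing silence, and any
// segments seen from silence start the speaking state.
func transition(state VADState, prevSegments, curSegments int, listening time.Duration) VADState {
    switch {
    case curSegments == 0:
        return state // no speech in this window; buffer trimming handles the rest
    case state == StateSpeaking && curSegments == prevSegments && listening > sendToVADDelay:
        return StateTrailingSilence // speech stopped growing: ready to commit
    case state == StateSilence:
        return StateSpeaking
    default:
        return state
    }
}

func main() {
    s := StateSilence
    s = transition(s, 0, 1, 0)             // first segments arrive
    s = transition(s, 1, 1, 2*time.Second) // count unchanged for >1s
    fmt.Println(s == StateTrailingSilence) // true
}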
@@ -503,7 +511,8 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 		cancel()
 	}()
 
-	audioDetected := false
+	vadState := VADState(StateSilence)
+	segments := []*proto.VADSegment{}
 	timeListening := time.Now()
 
 	// Implement VAD logic here
@ -520,15 +529,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
if len(session.InputAudioBuffer) > 0 { if len(session.InputAudioBuffer) > 0 {
if audioDetected && time.Since(timeListening) < maxWhisperSegmentDuration { if vadState == StateTrailingSilence {
log.Debug().Msgf("VAD detected speech, but still listening")
// audioDetected = false
// keep listening
session.AudioBufferLock.Unlock()
continue
}
if audioDetected {
log.Debug().Msgf("VAD detected speech that we can process") log.Debug().Msgf("VAD detected speech that we can process")
// Commit the audio buffer as a conversation item // Commit the audio buffer as a conversation item
@@ -561,7 +562,8 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 					Item: item,
 				})
-				audioDetected = false
+				vadState = StateSilence
+				segments = []*proto.VADSegment{}
 
 				// Generate a response
 				generateResponse(cfg, evaluator, session, conversation, ResponseCreate{}, c, websocket.TextMessage)
 				continue
@@ -570,7 +572,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 				adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
 
 				// Resample from 24kHz to 16kHz
-				adata = sound.ResampleInt16(adata, 24000, 16000)
+				// adata = sound.ResampleInt16(adata, 24000, 16000)
 
 				soundIntBuffer := &audio.IntBuffer{
 					Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
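The resampling step is disabled here, so the buffer is treated as 16kHz from this point on. For reference, a self-contained sketch of what a 24kHz-to-16kHz conversion like sound.ResampleInt16 does; this stand-in uses nearest-sample decimation, while the real helper may interpolate differently:

package main

import "fmt"

// resampleInt16 is an illustrative stand-in for sound.ResampleInt16:
// nearest-sample rate conversion from srcRate to dstRate.
func resampleInt16(in []int16, srcRate, dstRate int) []int16 {
    out := make([]int16, len(in)*dstRate/srcRate)
    for i := range out {
        out[i] = in[i*srcRate/dstRate]
    }
    return out
}

func main() {
    second := make([]int16, 24000)                        // one second at 24kHz
    fmt.Println(len(resampleInt16(second, 24000, 16000))) // 16000
}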
@@ -582,9 +584,20 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 						session.AudioBufferLock.Unlock()
 						continue
 					} */
 
 				float32Data := soundIntBuffer.AsFloat32Buffer().Data
 
+				// TODO: testing wav decoding
+				// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
+				// buf, err := dec.FullPCMBuffer()
+				// if err != nil {
+				//	//log.Error().Msgf("failed to process audio: %s", err.Error())
+				//	sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+				//	session.AudioBufferLock.Unlock()
+				//	continue
+				// }
+				//float32Data = buf.AsFloat32Buffer().Data
+
 				resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
 					Audio: float32Data,
 				})
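The VAD backend is fed raw float samples. A self-contained sketch of the conversion the surrounding code performs via sound.BytesToInt16sLE and AsFloat32Buffer; the scaling to [-1, 1) shown here is the conventional PCM16 normalization, and whether the real AsFloat32Buffer scales or merely widens depends on the go-audio version in use:

package main

import (
    "encoding/binary"
    "fmt"
)

// pcm16LEToFloat32 mirrors the pipeline above: little-endian PCM16 bytes are
// reinterpreted as int16 samples and widened to float32, the shape that
// proto.VADRequest expects in its Audio field.
func pcm16LEToFloat32(b []byte) []float32 {
    out := make([]float32, 0, len(b)/2)
    for i := 0; i+1 < len(b); i += 2 {
        s := int16(binary.LittleEndian.Uint16(b[i : i+2]))
        out = append(out, float32(s)/32768.0) // conventional PCM16 normalization
    }
    return out
}

func main() {
    raw := []byte{0x00, 0x40, 0x00, 0xC0} // two samples: 16384 and -16384
    fmt.Println(pcm16LEToFloat32(raw))    // [0.5 -0.5]
}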
@@ -598,20 +611,34 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 				if len(resp.Segments) == 0 {
 					log.Debug().Msg("VAD detected no speech activity")
 					log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
-					if !audioDetected {
+					if len(session.InputAudioBuffer) > 16000 {
 						session.InputAudioBuffer = nil
+						segments = []*proto.VADSegment{}
 					}
 
 					log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
-
-					session.AudioBufferLock.Unlock()
-					continue
-				}
-
-				if !audioDetected {
-					timeListening = time.Now()
+				} else if (len(resp.Segments) != len(segments)) && vadState == StateSpeaking {
+					// We have new segments, but we are still speaking
+					// We need to wait for the trailing silence
+					segments = resp.Segments
+
+				} else if (len(resp.Segments) == len(segments)) && vadState == StateSpeaking {
+					// We have the same number of segments, but we are still speaking
+					// We need to check if we are in this state for long enough, update the timer
+
+					// Check if we have been listening for too long
+					if time.Since(timeListening) > sendToVADDelay {
+						vadState = StateTrailingSilence
+					} else {
+						timeListening = timeListening.Add(time.Since(timeListening))
+					}
+				} else {
+					log.Debug().Msg("VAD detected speech activity")
+					vadState = StateSpeaking
+					segments = resp.Segments
 				}
-				audioDetected = true
 
 				session.AudioBufferLock.Unlock()
 			} else {
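One detail worth noting in the timer branch above: `timeListening.Add(time.Since(timeListening))` is arithmetically the current time, so this assignment restarts the listening timer on every poll that takes the else-branch. A minimal demonstration of the identity:

package main

import (
    "fmt"
    "time"
)

func main() {
    // timeListening + (now - timeListening) == now, so the assignment below
    // behaves the same as timeListening = time.Now().
    timeListening := time.Now().Add(-500 * time.Millisecond)
    timeListening = timeListening.Add(time.Since(timeListening))
    fmt.Println(time.Since(timeListening) < time.Millisecond) // true: reset to ~now
}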
@@ -843,101 +870,104 @@ func processTextResponse(config *config.BackendConfig, session *Session, prompt
 	// Replace this with actual model inference logic using session.Model and prompt
 	// For example, the model might return a special token or JSON indicating a function call
 
-	predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
-
-	result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
-		if !shouldUseFn {
-			// no function is called, just reply and use stop as finish reason
-			*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
-			return
-		}
-
-		textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
-		s = functions.CleanupLLMResult(s, config.FunctionsConfig)
-		results := functions.ParseFunctionCall(s, config.FunctionsConfig)
-		log.Debug().Msgf("Text content to return: %s", textContentToReturn)
-		noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
-
-		switch {
-		case noActionsToRun:
-			result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
-			if err != nil {
-				log.Error().Err(err).Msg("error handling question")
-				return
-			}
-
-			*c = append(*c, schema.Choice{
-				Message: &schema.Message{Role: "assistant", Content: &result}})
-		default:
-			toolChoice := schema.Choice{
-				Message: &schema.Message{
-					Role: "assistant",
-				},
-			}
-
-			if len(input.Tools) > 0 {
-				toolChoice.FinishReason = "tool_calls"
-			}
-
-			for _, ss := range results {
-				name, args := ss.Name, ss.Arguments
-				if len(input.Tools) > 0 {
-					// If we are using tools, we condense the function calls into
-					// a single response choice with all the tools
-					toolChoice.Message.Content = textContentToReturn
-					toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
-						schema.ToolCall{
-							ID:   id,
-							Type: "function",
-							FunctionCall: schema.FunctionCall{
-								Name:      name,
-								Arguments: args,
-							},
-						},
-					)
-				} else {
-					// otherwise we return more choices directly
-					*c = append(*c, schema.Choice{
-						FinishReason: "function_call",
-						Message: &schema.Message{
-							Role:    "assistant",
-							Content: &textContentToReturn,
-							FunctionCall: map[string]interface{}{
-								"name":      name,
-								"arguments": args,
-							},
-						},
-					})
-				}
-			}
-
-			if len(input.Tools) > 0 {
-				// we need to append our result if we are using tools
-				*c = append(*c, toolChoice)
-			}
-		}
-	}, nil)
-	if err != nil {
-		return err
-	}
-
-	resp := &schema.OpenAIResponse{
-		ID:      id,
-		Created: created,
-		Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
-		Choices: result,
-		Object:  "chat.completion",
-		Usage: schema.OpenAIUsage{
-			PromptTokens:     tokenUsage.Prompt,
-			CompletionTokens: tokenUsage.Completion,
-			TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
-		},
-	}
-	respData, _ := json.Marshal(resp)
-	log.Debug().Msgf("Response: %s", respData)
-
-	// Return the prediction in the response body
-	return c.JSON(resp)
+	/*
+		predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
+
+		result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
+			if !shouldUseFn {
+				// no function is called, just reply and use stop as finish reason
+				*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
+				return
+			}
+
+			textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
+			s = functions.CleanupLLMResult(s, config.FunctionsConfig)
+			results := functions.ParseFunctionCall(s, config.FunctionsConfig)
+			log.Debug().Msgf("Text content to return: %s", textContentToReturn)
+			noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
+
+			switch {
+			case noActionsToRun:
+				result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
+				if err != nil {
+					log.Error().Err(err).Msg("error handling question")
+					return
+				}
+
+				*c = append(*c, schema.Choice{
+					Message: &schema.Message{Role: "assistant", Content: &result}})
+			default:
+				toolChoice := schema.Choice{
+					Message: &schema.Message{
+						Role: "assistant",
+					},
+				}
+
+				if len(input.Tools) > 0 {
+					toolChoice.FinishReason = "tool_calls"
+				}
+
+				for _, ss := range results {
+					name, args := ss.Name, ss.Arguments
+					if len(input.Tools) > 0 {
+						// If we are using tools, we condense the function calls into
+						// a single response choice with all the tools
+						toolChoice.Message.Content = textContentToReturn
+						toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
+							schema.ToolCall{
+								ID:   id,
+								Type: "function",
+								FunctionCall: schema.FunctionCall{
+									Name:      name,
+									Arguments: args,
+								},
+							},
+						)
+					} else {
+						// otherwise we return more choices directly
+						*c = append(*c, schema.Choice{
+							FinishReason: "function_call",
+							Message: &schema.Message{
+								Role:    "assistant",
+								Content: &textContentToReturn,
+								FunctionCall: map[string]interface{}{
+									"name":      name,
+									"arguments": args,
+								},
+							},
+						})
+					}
+				}
+
+				if len(input.Tools) > 0 {
+					// we need to append our result if we are using tools
+					*c = append(*c, toolChoice)
+				}
+			}
+		}, nil)
+		if err != nil {
+			return err
+		}
+
+		resp := &schema.OpenAIResponse{
+			ID:      id,
+			Created: created,
+			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+			Choices: result,
+			Object:  "chat.completion",
+			Usage: schema.OpenAIUsage{
+				PromptTokens:     tokenUsage.Prompt,
+				CompletionTokens: tokenUsage.Completion,
+				TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
+			},
+		}
+		respData, _ := json.Marshal(resp)
+		log.Debug().Msgf("Response: %s", respData)
+
+		// Return the prediction in the response body
+		return c.JSON(resp)
+	*/
 
 	// TODO: use session.ModelInterface...
 	// Simulate a function call