diff --git a/backend/backend.proto b/backend/backend.proto index cdf09bf2..1da19859 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -162,6 +162,7 @@ message Reply { int32 prompt_tokens = 3; double timing_prompt_processing = 4; double timing_token_generation = 5; + bytes audio = 6; } message GrammarTrigger { diff --git a/backend/go/vad/silero/vad.go b/backend/go/vad/silero/vad.go index 5a164d2a..31b3c897 100644 --- a/backend/go/vad/silero/vad.go +++ b/backend/go/vad/silero/vad.go @@ -21,8 +21,8 @@ func (vad *VAD) Load(opts *pb.ModelOptions) error { SampleRate: 16000, //WindowSize: 1024, Threshold: 0.5, - MinSilenceDurationMs: 0, - SpeechPadMs: 0, + MinSilenceDurationMs: 100, + SpeechPadMs: 30, }) if err != nil { return fmt.Errorf("create silero detector: %w", err) diff --git a/core/backend/llm.go b/core/backend/llm.go index 57e2ae35..f36a568a 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -22,8 +22,9 @@ import ( ) type LLMResponse struct { - Response string // should this be []byte? - Usage TokenUsage + Response string // should this be []byte? + Usage TokenUsage + AudioOutput string } type TokenUsage struct { diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 5c436400..b2ccfe90 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -37,6 +37,7 @@ type BackendConfig struct { TemplateConfig TemplateConfig `yaml:"template"` KnownUsecaseStrings []string `yaml:"known_usecases"` KnownUsecases *BackendConfigUsecases `yaml:"-"` + Pipeline Pipeline `yaml:"pipeline"` PromptStrings, InputStrings []string `yaml:"-"` InputToken [][]int `yaml:"-"` @@ -72,6 +73,18 @@ type BackendConfig struct { Options []string `yaml:"options"` } +// Pipeline defines other models to use for audio-to-audio +type Pipeline struct { + TTS string `yaml:"tts"` + LLM string `yaml:"llm"` + Transcription string `yaml:"transcription"` + VAD string `yaml:"vad"` +} + +func (p Pipeline) IsNotConfigured() bool { + return p.LLM == "" || p.TTS == "" || p.Transcription == "" +} + type File struct { Filename string `yaml:"filename" json:"filename"` SHA256 string `yaml:"sha256" json:"sha256"` diff --git a/core/http/app.go b/core/http/app.go index 0edd7ef1..68e4c264 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -9,6 +9,7 @@ import ( "path/filepath" "github.com/dave-gray101/v2keyauth" + "github.com/gofiber/websocket/v2" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/core/http/endpoints/localai" @@ -91,13 +92,21 @@ func API(application *application.Application) (*fiber.App, error) { router.Use(middleware.StripPathPrefix()) if application.ApplicationConfig().MachineTag != "" { router.Use(func(c *fiber.Ctx) error { c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag) return c.Next() }) } + router.Use("/v1/realtime", func(c *fiber.Ctx) error { + if websocket.IsWebSocketUpgrade(c) { + // Returns true if the client requested upgrade to the WebSocket protocol + return c.Next() + } + + return nil + }) router.Hooks().OnListen(func(listenData fiber.ListenData) error { scheme := "http" diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go new file mode 100644 index 00000000..6f6b774d --- /dev/null +++ b/core/http/endpoints/openai/realtime.go @@ -0,0 +1,1136 @@ +package openai + +import ( + 
"bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/go-audio/wav" + + "github.com/go-audio/audio" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/websocket/v2" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/LocalAI/pkg/grpc/proto" + model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/sound" + "github.com/mudler/LocalAI/pkg/templates" + + "google.golang.org/grpc" + + "github.com/rs/zerolog/log" +) + +// A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result +// If the model support instead audio-to-audio, we will use the specific gRPC calls instead + +// Session represents a single WebSocket connection and its state +type Session struct { + ID string + Model string + Voice string + TurnDetection *TurnDetection `json:"turn_detection"` // "server_vad" or "none" + Functions functions.Functions + Conversations map[string]*Conversation + InputAudioBuffer []byte + AudioBufferLock sync.Mutex + Instructions string + DefaultConversationID string + ModelInterface Model +} + +type TurnDetection struct { + Type string `json:"type"` +} + +// FunctionCall represents a function call initiated by the model +type FunctionCall struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// Conversation represents a conversation with a list of items +type Conversation struct { + ID string + Items []*Item + Lock sync.Mutex +} + +// Item represents a message, function_call, or function_call_output +type Item struct { + ID string `json:"id"` + Object string `json:"object"` + Type string `json:"type"` // "message", "function_call", "function_call_output" + Status string `json:"status"` + Role string `json:"role"` + Content []ConversationContent `json:"content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` +} + +// ConversationContent represents the content of an item +type ConversationContent struct { + Type string `json:"type"` // "input_text", "input_audio", "text", "audio", etc. 
+ Audio string `json:"audio,omitempty"` + Text string `json:"text,omitempty"` + // Additional fields as needed +} + +// Define the structures for incoming messages +type IncomingMessage struct { + Type string `json:"type"` + Session json.RawMessage `json:"session,omitempty"` + Item json.RawMessage `json:"item,omitempty"` + Audio string `json:"audio,omitempty"` + Response json.RawMessage `json:"response,omitempty"` + Error *ErrorMessage `json:"error,omitempty"` + // Other fields as needed +} + +// ErrorMessage represents an error message sent to the client +type ErrorMessage struct { + Type string `json:"type"` + Code string `json:"code"` + Message string `json:"message"` + Param string `json:"param,omitempty"` + EventID string `json:"event_id,omitempty"` +} + +// Define a structure for outgoing messages +type OutgoingMessage struct { + Type string `json:"type"` + Session *Session `json:"session,omitempty"` + Conversation *Conversation `json:"conversation,omitempty"` + Item *Item `json:"item,omitempty"` + Content string `json:"content,omitempty"` + Audio string `json:"audio,omitempty"` + Error *ErrorMessage `json:"error,omitempty"` +} + +// Map to store sessions (in-memory) +var sessions = make(map[string]*Session) +var sessionLock sync.Mutex + +// TODO: implement interface as we start to define usages +type Model interface { + VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) + Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) + PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error +} + +func Realtime(application *application.Application) fiber.Handler { + return websocket.New(registerRealtime(application)) +} + +func registerRealtime(application *application.Application) func(c *websocket.Conn) { + return func(c *websocket.Conn) { + + evaluator := application.TemplatesEvaluator() + log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String()) + + model := c.Params("model") + if model == "" { + model = "gpt-4o" + } + + log.Info().Msgf("New session with model: %s", model) + + sessionID := generateSessionID() + session := &Session{ + ID: sessionID, + Model: model, // default model + Voice: "alloy", // default voice + TurnDetection: &TurnDetection{Type: "none"}, + Conversations: make(map[string]*Conversation), + } + + // Create a default conversation + conversationID := generateConversationID() + conversation := &Conversation{ + ID: conversationID, + Items: []*Item{}, + } + session.Conversations[conversationID] = conversation + session.DefaultConversationID = conversationID + + cfg, err := application.BackendLoader().LoadBackendConfigFileByName(model, application.ModelLoader().ModelPath) + if err != nil { + log.Error().Msgf("failed to load model (no config): %s", err.Error()) + sendError(c, "model_load_error", "Failed to load model (no config)", "", "") + return + } + + m, err := newModel( + cfg, + application.BackendLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + model, + ) + if err != nil { + log.Error().Msgf("failed to load model: %s", err.Error()) + sendError(c, "model_load_error", "Failed to load model", "", "") + return + } + session.ModelInterface = m + + // Store the session + sessionLock.Lock() + sessions[sessionID] = session + sessionLock.Unlock() + + // Send session.created and conversation.created events to the client + sendEvent(c, OutgoingMessage{ + Type: 
"session.created", + Session: session, + }) + sendEvent(c, OutgoingMessage{ + Type: "conversation.created", + Conversation: conversation, + }) + + var ( + mt int + msg []byte + wg sync.WaitGroup + done = make(chan struct{}) + ) + + var vadServerStarted bool + + for { + if mt, msg, err = c.ReadMessage(); err != nil { + log.Error().Msgf("read: %s", err.Error()) + break + } + + // Parse the incoming message + var incomingMsg IncomingMessage + if err := json.Unmarshal(msg, &incomingMsg); err != nil { + log.Error().Msgf("invalid json: %s", err.Error()) + sendError(c, "invalid_json", "Invalid JSON format", "", "") + continue + } + + switch incomingMsg.Type { + case "session.update": + log.Printf("recv: %s", msg) + + // Update session configurations + var sessionUpdate Session + if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil { + log.Error().Msgf("failed to unmarshal 'session.update': %s", err.Error()) + sendError(c, "invalid_session_update", "Invalid session update format", "", "") + continue + } + if err := updateSession( + session, + &sessionUpdate, + application.BackendLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + ); err != nil { + log.Error().Msgf("failed to update session: %s", err.Error()) + sendError(c, "session_update_error", "Failed to update session", "", "") + continue + } + + // Acknowledge the session update + sendEvent(c, OutgoingMessage{ + Type: "session.updated", + Session: session, + }) + + if session.TurnDetection.Type == "server_vad" && !vadServerStarted { + log.Debug().Msg("Starting VAD goroutine...") + wg.Add(1) + go func() { + defer wg.Done() + conversation := session.Conversations[session.DefaultConversationID] + handleVAD(cfg, evaluator, session, conversation, c, done) + }() + vadServerStarted = true + } else if vadServerStarted { + log.Debug().Msg("Stopping VAD goroutine...") + + wg.Add(-1) + go func() { + done <- struct{}{} + }() + vadServerStarted = false + } + case "input_audio_buffer.append": + // Handle 'input_audio_buffer.append' + if incomingMsg.Audio == "" { + log.Error().Msg("Audio data is missing in 'input_audio_buffer.append'") + sendError(c, "missing_audio_data", "Audio data is missing", "", "") + continue + } + + // Decode base64 audio data + decodedAudio, err := base64.StdEncoding.DecodeString(incomingMsg.Audio) + if err != nil { + log.Error().Msgf("failed to decode audio data: %s", err.Error()) + sendError(c, "invalid_audio_data", "Failed to decode audio data", "", "") + continue + } + + // Append to InputAudioBuffer + session.AudioBufferLock.Lock() + session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...) 
+ session.AudioBufferLock.Unlock() + + case "input_audio_buffer.commit": + log.Printf("recv: %s", msg) + + // Commit the audio buffer to the conversation as a new item + item := &Item{ + ID: generateItemID(), + Object: "realtime.item", + Type: "message", + Status: "completed", + Role: "user", + Content: []ConversationContent{ + { + Type: "input_audio", + Audio: base64.StdEncoding.EncodeToString(session.InputAudioBuffer), + }, + }, + } + + // Add item to conversation + conversation.Lock.Lock() + conversation.Items = append(conversation.Items, item) + conversation.Lock.Unlock() + + // Reset InputAudioBuffer + session.AudioBufferLock.Lock() + session.InputAudioBuffer = nil + session.AudioBufferLock.Unlock() + + // Send item.created event + sendEvent(c, OutgoingMessage{ + Type: "conversation.item.created", + Item: item, + }) + + case "conversation.item.create": + log.Printf("recv: %s", msg) + + // Handle creating new conversation items + var item Item + if err := json.Unmarshal(incomingMsg.Item, &item); err != nil { + log.Error().Msgf("failed to unmarshal 'conversation.item.create': %s", err.Error()) + sendError(c, "invalid_item", "Invalid item format", "", "") + continue + } + + // Generate item ID and set status + item.ID = generateItemID() + item.Object = "realtime.item" + item.Status = "completed" + + // Add item to conversation + conversation.Lock.Lock() + conversation.Items = append(conversation.Items, &item) + conversation.Lock.Unlock() + + // Send item.created event + sendEvent(c, OutgoingMessage{ + Type: "conversation.item.created", + Item: &item, + }) + + case "conversation.item.delete": + log.Printf("recv: %s", msg) + + // Handle deleting conversation items + // Implement deletion logic as needed + + case "response.create": + log.Printf("recv: %s", msg) + + // Handle generating a response + var responseCreate ResponseCreate + if len(incomingMsg.Response) > 0 { + if err := json.Unmarshal(incomingMsg.Response, &responseCreate); err != nil { + log.Error().Msgf("failed to unmarshal 'response.create' response object: %s", err.Error()) + sendError(c, "invalid_response_create", "Invalid response create format", "", "") + continue + } + } + + // Update session functions if provided + if len(responseCreate.Functions) > 0 { + session.Functions = responseCreate.Functions + } + + // Generate a response based on the conversation history + wg.Add(1) + go func() { + defer wg.Done() + generateResponse(cfg, evaluator, session, conversation, responseCreate, c, mt) + }() + + case "conversation.item.update": + log.Printf("recv: %s", msg) + + // Handle function_call_output from the client + var item Item + if err := json.Unmarshal(incomingMsg.Item, &item); err != nil { + log.Error().Msgf("failed to unmarshal 'conversation.item.update': %s", err.Error()) + sendError(c, "invalid_item_update", "Invalid item update format", "", "") + continue + } + + // Add the function_call_output item to the conversation + item.ID = generateItemID() + item.Object = "realtime.item" + item.Status = "completed" + + conversation.Lock.Lock() + conversation.Items = append(conversation.Items, &item) + conversation.Lock.Unlock() + + // Send item.updated event + sendEvent(c, OutgoingMessage{ + Type: "conversation.item.updated", + Item: &item, + }) + + case "response.cancel": + log.Printf("recv: %s", msg) + + // Handle cancellation of ongoing responses + // Implement cancellation logic as needed + + default: + log.Error().Msgf("unknown message type: %s", incomingMsg.Type) + sendError(c, "unknown_message_type", fmt.Sprintf("Unknown 
message type: %s", incomingMsg.Type), "", "") + } + } + + // Close the done channel to signal goroutines to exit + close(done) + wg.Wait() + + // Remove the session from the sessions map + sessionLock.Lock() + delete(sessions, sessionID) + sessionLock.Unlock() + } +} + +// Helper function to send events to the client +func sendEvent(c *websocket.Conn, event OutgoingMessage) { + eventBytes, err := json.Marshal(event) + if err != nil { + log.Error().Msgf("failed to marshal event: %s", err.Error()) + return + } + if err = c.WriteMessage(websocket.TextMessage, eventBytes); err != nil { + log.Error().Msgf("write: %s", err.Error()) + } +} + +// Helper function to send errors to the client +func sendError(c *websocket.Conn, code, message, param, eventID string) { + errorEvent := OutgoingMessage{ + Type: "error", + Error: &ErrorMessage{ + Type: "error", + Code: code, + Message: message, + Param: param, + EventID: eventID, + }, + } + sendEvent(c, errorEvent) +} + +// Function to update session configurations +func updateSession(session *Session, update *Session, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { + sessionLock.Lock() + defer sessionLock.Unlock() + + if update.Model != "" { + cfg, err := cl.LoadBackendConfigFileByName(update.Model, ml.ModelPath) + if err != nil { + return err + } + + m, err := newModel(cfg, cl, ml, appConfig, update.Model) + if err != nil { + return err + } + session.ModelInterface = m + session.Model = update.Model + } + + if update.Voice != "" { + session.Voice = update.Voice + } + if update.TurnDetection != nil && update.TurnDetection.Type != "" { + session.TurnDetection.Type = update.TurnDetection.Type + } + if update.Instructions != "" { + session.Instructions = update.Instructions + } + if update.Functions != nil { + session.Functions = update.Functions + } + + return nil +} + +const ( + sendToVADDelay = 2 * time.Second + silenceThreshold = 2 * time.Second +) + +// handleVAD is a goroutine that listens for audio data from the client, +// runs VAD on the audio data, and commits utterances to the conversation +func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) { + vadContext, cancel := context.WithCancel(context.Background()) + go func() { + <-done + cancel() + }() + + ticker := time.NewTicker(300 * time.Millisecond) + defer ticker.Stop() + + var ( + lastSegmentCount int + timeOfLastNewSeg time.Time + speaking bool + ) + + for { + select { + case <-done: + return + case <-ticker.C: + // 1) Copy the entire buffer + session.AudioBufferLock.Lock() + allAudio := make([]byte, len(session.InputAudioBuffer)) + copy(allAudio, session.InputAudioBuffer) + session.AudioBufferLock.Unlock() + + // 2) If there's no audio at all, or just too small samples, just continue + if len(allAudio) == 0 || len(allAudio) < 32000 { + continue + } + + // 3) Run VAD on the entire audio so far + segments, err := runVAD(vadContext, session, allAudio) + if err != nil { + if err.Error() == "unexpected speech end" { + log.Debug().Msg("VAD cancelled") + continue + } + log.Error().Msgf("failed to process audio: %s", err.Error()) + sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "") + // handle or log error, continue + continue + } + + segCount := len(segments) + + if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > silenceThreshold { + // no speech detected, and we haven't seen a new segment in > 1s + // 
clean up input + session.AudioBufferLock.Lock() + session.InputAudioBuffer = nil + session.AudioBufferLock.Unlock() + log.Debug().Msgf("Detected silence for a while, clearing audio buffer") + continue + } + + // 4) If we see more segments than before => "new speech" + if segCount > lastSegmentCount { + speaking = true + lastSegmentCount = segCount + timeOfLastNewSeg = time.Now() + log.Debug().Msgf("Detected new speech segment") + } + + // 5) If speaking, but we haven't seen a new segment in > 1s => finalize + if speaking && time.Since(timeOfLastNewSeg) > sendToVADDelay { + log.Debug().Msgf("Detected end of speech segment") + session.AudioBufferLock.Lock() + session.InputAudioBuffer = nil + session.AudioBufferLock.Unlock() + // user has presumably stopped talking + commitUtterance(allAudio, cfg, evaluator, session, conv, c) + // reset state + speaking = false + lastSegmentCount = 0 + } + } + } +} + +func commitUtterance(utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) { + if len(utt) == 0 { + return + } + // Commit logic: create item, broadcast item.created, etc. + item := &Item{ + ID: generateItemID(), + Object: "realtime.item", + Type: "message", + Status: "completed", + Role: "user", + Content: []ConversationContent{ + { + Type: "input_audio", + Audio: base64.StdEncoding.EncodeToString(utt), + }, + }, + } + conv.Lock.Lock() + conv.Items = append(conv.Items, item) + conv.Lock.Unlock() + + sendEvent(c, OutgoingMessage{ + Type: "conversation.item.created", + Item: item, + }) + + // save chunk to disk + f, err := os.CreateTemp("", "audio-*.wav") + if err != nil { + log.Error().Msgf("failed to create temp file: %s", err.Error()) + return + } + defer f.Close() + //defer os.Remove(f.Name()) + log.Debug().Msgf("Writing to %s\n", f.Name()) + + f.Write(utt) + f.Sync() + + // trigger the response generation + generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage) +} + +// runVAD is a helper that calls the model's VAD method, returning +// true if it detects speech, false if it detects silence +func runVAD(ctx context.Context, session *Session, chunk []byte) ([]*proto.VADSegment, error) { + + adata := sound.BytesToInt16sLE(chunk) + + // Resample from 24kHz to 16kHz + adata = sound.ResampleInt16(adata, 24000, 16000) + + dec := wav.NewDecoder(bytes.NewReader(chunk)) + dur, err := dec.Duration() + if err != nil { + fmt.Printf("failed to get duration: %s\n", err) + } + fmt.Printf("duration: %s\n", dur) + + soundIntBuffer := &audio.IntBuffer{ + Format: &audio.Format{SampleRate: 16000, NumChannels: 1}, + } + soundIntBuffer.Data = sound.ConvertInt16ToInt(adata) + + /* if len(adata) < 16000 { + log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer)) + session.AudioBufferLock.Unlock() + continue + } */ + float32Data := soundIntBuffer.AsFloat32Buffer().Data + + resp, err := session.ModelInterface.VAD(ctx, &proto.VADRequest{ + Audio: float32Data, + }) + if err != nil { + return nil, err + } + + // TODO: testing wav decoding + // dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer)) + // buf, err := dec.FullPCMBuffer() + // if err != nil { + // //log.Error().Msgf("failed to process audio: %s", err.Error()) + // sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "") + // session.AudioBufferLock.Unlock() + // continue + // } + + //float32Data = buf.AsFloat32Buffer().Data + + // If resp.Segments is empty => no speech + return resp.Segments, 
nil +} + +// Function to generate a response based on the conversation +func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) { + + log.Debug().Msg("Generating realtime response...") + + // Compile the conversation history + conversation.Lock.Lock() + var conversationHistory []schema.Message + var latestUserAudio string + for _, item := range conversation.Items { + for _, content := range item.Content { + switch content.Type { + case "input_text", "text": + conversationHistory = append(conversationHistory, schema.Message{ + Role: item.Role, + StringContent: content.Text, + Content: content.Text, + }) + case "input_audio": + // We do not to turn to text here the audio result. + // When generating it later on from the LLM, + // we will also generate text and return it and store it in the conversation + // Here we just want to get the user audio if there is any as a new input for the conversation. + if item.Role == "user" { + latestUserAudio = content.Audio + } + } + } + } + + conversation.Lock.Unlock() + + var generatedText string + var generatedAudio []byte + var functionCall *FunctionCall + var err error + + if latestUserAudio != "" { + // Process the latest user audio input + decodedAudio, err := base64.StdEncoding.DecodeString(latestUserAudio) + if err != nil { + log.Error().Msgf("failed to decode latest user audio: %s", err.Error()) + sendError(c, "invalid_audio_data", "Failed to decode audio data", "", "") + return + } + + // Process the audio input and generate a response + generatedText, generatedAudio, functionCall, err = processAudioResponse(session, decodedAudio) + if err != nil { + log.Error().Msgf("failed to process audio response: %s", err.Error()) + sendError(c, "processing_error", "Failed to generate audio response", "", "") + return + } + } else { + + if session.Instructions != "" { + conversationHistory = append([]schema.Message{{ + Role: "system", + StringContent: session.Instructions, + Content: session.Instructions, + }}, conversationHistory...) + } + + funcs := session.Functions + shouldUseFn := len(funcs) > 0 && config.ShouldUseFunctions() + + // Allow the user to set custom actions via config file + // to be "embedded" in each model + noActionName := "answer" + noActionDescription := "use this action to answer without performing any action" + + if config.FunctionsConfig.NoActionFunctionName != "" { + noActionName = config.FunctionsConfig.NoActionFunctionName + } + if config.FunctionsConfig.NoActionDescriptionName != "" { + noActionDescription = config.FunctionsConfig.NoActionDescriptionName + } + + if (!config.FunctionsConfig.GrammarConfig.NoGrammar) && shouldUseFn { + noActionGrammar := functions.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, + } + + // Append the no action function + if !config.FunctionsConfig.DisableNoAction { + funcs = append(funcs, noActionGrammar) + } + + // Update input grammar + jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey) + g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...) 
+ if err == nil { + config.Grammar = g + } + } + + // Generate a response based on text conversation history + prompt := evaluator.TemplateMessages(conversationHistory, config, funcs, shouldUseFn) + + generatedText, functionCall, err = processTextResponse(config, session, prompt) + if err != nil { + log.Error().Msgf("failed to process text response: %s", err.Error()) + sendError(c, "processing_error", "Failed to generate text response", "", "") + return + } + log.Debug().Any("text", generatedText).Msg("Generated text response") + } + + if functionCall != nil { + // The model wants to call a function + // Create a function_call item and send it to the client + item := &Item{ + ID: generateItemID(), + Object: "realtime.item", + Type: "function_call", + Status: "completed", + Role: "assistant", + FunctionCall: functionCall, + } + + // Add item to conversation + conversation.Lock.Lock() + conversation.Items = append(conversation.Items, item) + conversation.Lock.Unlock() + + // Send item.created event + sendEvent(c, OutgoingMessage{ + Type: "conversation.item.created", + Item: item, + }) + + // Optionally, you can generate a message to the user indicating the function call + // For now, we'll assume the client handles the function call and may trigger another response + + } else { + // Send response.stream messages + if generatedAudio != nil { + // If generatedAudio is available, send it as audio + encodedAudio := base64.StdEncoding.EncodeToString(generatedAudio) + outgoingMsg := OutgoingMessage{ + Type: "response.stream", + Audio: encodedAudio, + } + sendEvent(c, outgoingMsg) + } else { + // Send text response (could be streamed in chunks) + chunks := splitResponseIntoChunks(generatedText) + for _, chunk := range chunks { + outgoingMsg := OutgoingMessage{ + Type: "response.stream", + Content: chunk, + } + sendEvent(c, outgoingMsg) + } + } + + // Send response.done message + sendEvent(c, OutgoingMessage{ + Type: "response.done", + }) + + // Add the assistant's response to the conversation + content := []ConversationContent{} + if generatedAudio != nil { + content = append(content, ConversationContent{ + Type: "audio", + Audio: base64.StdEncoding.EncodeToString(generatedAudio), + }) + // Optionally include a text transcript + if generatedText != "" { + content = append(content, ConversationContent{ + Type: "text", + Text: generatedText, + }) + } + } else { + content = append(content, ConversationContent{ + Type: "text", + Text: generatedText, + }) + } + + item := &Item{ + ID: generateItemID(), + Object: "realtime.item", + Type: "message", + Status: "completed", + Role: "assistant", + Content: content, + } + + // Add item to conversation + conversation.Lock.Lock() + conversation.Items = append(conversation.Items, item) + conversation.Lock.Unlock() + + // Send item.created event + sendEvent(c, OutgoingMessage{ + Type: "conversation.item.created", + Item: item, + }) + + log.Debug().Any("item", item).Msg("Realtime response sent") + } +} + +// Function to process text response and detect function calls +func processTextResponse(config *config.BackendConfig, session *Session, prompt string) (string, *FunctionCall, error) { + + // Placeholder implementation + // Replace this with actual model inference logic using session.Model and prompt + // For example, the model might return a special token or JSON indicating a function call + + /* + predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil) + + result, tokenUsage, err := 
ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) { + if !shouldUseFn { + // no function is called, just reply and use stop as finish reason + *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) + return + } + + textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig) + s = functions.CleanupLLMResult(s, config.FunctionsConfig) + results := functions.ParseFunctionCall(s, config.FunctionsConfig) + log.Debug().Msgf("Text content to return: %s", textContentToReturn) + noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0 + + switch { + case noActionsToRun: + result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput) + if err != nil { + log.Error().Err(err).Msg("error handling question") + return + } + *c = append(*c, schema.Choice{ + Message: &schema.Message{Role: "assistant", Content: &result}}) + default: + toolChoice := schema.Choice{ + Message: &schema.Message{ + Role: "assistant", + }, + } + + if len(input.Tools) > 0 { + toolChoice.FinishReason = "tool_calls" + } + + for _, ss := range results { + name, args := ss.Name, ss.Arguments + if len(input.Tools) > 0 { + // If we are using tools, we condense the function calls into + // a single response choice with all the tools + toolChoice.Message.Content = textContentToReturn + toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, + schema.ToolCall{ + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, + }, + }, + ) + } else { + // otherwise we return more choices directly + *c = append(*c, schema.Choice{ + FinishReason: "function_call", + Message: &schema.Message{ + Role: "assistant", + Content: &textContentToReturn, + FunctionCall: map[string]interface{}{ + "name": name, + "arguments": args, + }, + }, + }) + } + } + + if len(input.Tools) > 0 { + // we need to append our result if we are using tools + *c = append(*c, toolChoice) + } + } + + }, nil) + if err != nil { + return err + } + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + respData, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", respData) + + // Return the prediction in the response body + return c.JSON(resp) + + */ + + // TODO: use session.ModelInterface... + // Simulate a function call + if strings.Contains(prompt, "weather") { + functionCall := &FunctionCall{ + Name: "get_weather", + Arguments: map[string]interface{}{ + "location": "New York", + "scale": "celsius", + }, + } + return "", functionCall, nil + } + + // Otherwise, return a normal text response + return "This is a generated response based on the conversation.", nil, nil +} + +// Function to process audio response and detect function calls +func processAudioResponse(session *Session, audioData []byte) (string, []byte, *FunctionCall, error) { + // Implement the actual model inference logic using session.Model and audioData + // For example: + // 1. Transcribe the audio to text + // 2. Generate a response based on the transcribed text + // 3. Check if the model wants to call a function + // 4. 
Convert the response text to speech (audio) + // + // Placeholder implementation: + + // TODO: template eventual messages, like chat.go + reply, err := session.ModelInterface.Predict(context.Background(), &proto.PredictOptions{ + Prompt: "What's the weather in New York?", + }) + + if err != nil { + return "", nil, nil, err + } + + generatedAudio := reply.Audio + + transcribedText := "What's the weather in New York?" + var functionCall *FunctionCall + + // Simulate a function call + if strings.Contains(transcribedText, "weather") { + functionCall = &FunctionCall{ + Name: "get_weather", + Arguments: map[string]interface{}{ + "location": "New York", + "scale": "celsius", + }, + } + return "", nil, functionCall, nil + } + + // Generate a response + generatedText := "This is a response to your speech input." + + return generatedText, generatedAudio, nil, nil +} + +// Function to split the response into chunks (for streaming) +func splitResponseIntoChunks(response string) []string { + // Split the response into chunks of fixed size + chunkSize := 50 // characters per chunk + var chunks []string + for len(response) > 0 { + if len(response) > chunkSize { + chunks = append(chunks, response[:chunkSize]) + response = response[chunkSize:] + } else { + chunks = append(chunks, response) + break + } + } + return chunks +} + +// Helper functions to generate unique IDs +func generateSessionID() string { + // Generate a unique session ID + // Implement as needed + return "sess_" + generateUniqueID() +} + +func generateConversationID() string { + // Generate a unique conversation ID + // Implement as needed + return "conv_" + generateUniqueID() +} + +func generateItemID() string { + // Generate a unique item ID + // Implement as needed + return "item_" + generateUniqueID() +} + +func generateUniqueID() string { + // Generate a unique ID string + // For simplicity, use a counter or UUID + // Implement as needed + return "unique_id" +} + +// Structures for 'response.create' messages +type ResponseCreate struct { + Modalities []string `json:"modalities,omitempty"` + Instructions string `json:"instructions,omitempty"` + Functions functions.Functions `json:"functions,omitempty"` + // Other fields as needed +} + +/* +func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, firstModel bool) func(c *websocket.Conn) { + return func(c *websocket.Conn) { + modelFile, input, err := readRequest(c, cl, ml, appConfig, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + var ( + mt int + msg []byte + err error + ) + for { + if mt, msg, err = c.ReadMessage(); err != nil { + log.Error().Msgf("read: %s", err.Error()) + break + } + log.Printf("recv: %s", msg) + + if err = c.WriteMessage(mt, msg); err != nil { + log.Error().Msgf("write: %s", err.Error()) + break + } + } + } +} + +*/ diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go new file mode 100644 index 00000000..815bbb1d --- /dev/null +++ b/core/http/endpoints/openai/realtime_model.go @@ -0,0 +1,186 @@ +package openai + +import ( + "context" + "fmt" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + grpcClient "github.com/mudler/LocalAI/pkg/grpc" + "github.com/mudler/LocalAI/pkg/grpc/proto" + model "github.com/mudler/LocalAI/pkg/model" + "github.com/rs/zerolog/log" + "google.golang.org/grpc" +) + +var ( + _ Model = new(wrappedModel) + _ Model = new(anyToAnyModel) +) + +// 
wrappedModel represent a model which does not support Any-to-Any operations +// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods +// which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS) +type wrappedModel struct { + TTSConfig *config.BackendConfig + TranscriptionConfig *config.BackendConfig + LLMConfig *config.BackendConfig + TTSClient grpcClient.Backend + TranscriptionClient grpcClient.Backend + LLMClient grpcClient.Backend + + VADConfig *config.BackendConfig + VADClient grpcClient.Backend +} + +// anyToAnyModel represent a model which supports Any-to-Any operations +// We have to wrap this out as well because we want to load two models one for VAD and one for the actual model. +// In the future there could be models that accept continous audio input only so this design will be useful for that +type anyToAnyModel struct { + LLMConfig *config.BackendConfig + LLMClient grpcClient.Backend + + VADConfig *config.BackendConfig + VADClient grpcClient.Backend +} + +func (m *wrappedModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) { + return m.VADClient.VAD(ctx, in) +} + +func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) { + return m.VADClient.VAD(ctx, in) +} + +func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) { + // TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it) + // sound.BufferAsWAV(audioData, "audio.wav") + + return m.LLMClient.Predict(ctx, in) +} + +func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error { + // TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it) + + return m.LLMClient.PredictStream(ctx, in, f) +} + +func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) { + return m.LLMClient.Predict(ctx, in) +} + +func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error { + return m.LLMClient.PredictStream(ctx, in, f) +} + +// returns and loads either a wrapped model or a model that support audio-to-audio +func newModel(cfg *config.BackendConfig, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) { + + // Prepare VAD model + cfgVAD, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.VAD, ml.ModelPath) + if err != nil { + + return nil, fmt.Errorf("failed to load backend config: %w", err) + } + + if !cfgVAD.Validate() { + return nil, fmt.Errorf("failed to validate config: %w", err) + } + + opts := backend.ModelOptions(*cfgVAD, appConfig) + VADClient, err := ml.Load(opts...) 
+ if err != nil { + return nil, fmt.Errorf("failed to load tts model: %w", err) + } + + // If we don't have Wrapped model definitions, just return a standard model + if cfg.Pipeline.IsNotConfigured() { + + // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations + cfgAnyToAny, err := cl.LoadBackendConfigFileByName(cfg.Model, ml.ModelPath) + if err != nil { + + return nil, fmt.Errorf("failed to load backend config: %w", err) + } + + if !cfgAnyToAny.Validate() { + return nil, fmt.Errorf("failed to validate config: %w", err) + } + + opts := backend.ModelOptions(*cfgAnyToAny, appConfig) + anyToAnyClient, err := ml.Load(opts...) + if err != nil { + return nil, fmt.Errorf("failed to load tts model: %w", err) + } + + return &anyToAnyModel{ + LLMConfig: cfgAnyToAny, + LLMClient: anyToAnyClient, + VADConfig: cfgVAD, + VADClient: VADClient, + }, nil + } + + log.Debug().Msg("Loading a wrapped model") + + // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations + cfgLLM, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.LLM, ml.ModelPath) + if err != nil { + + return nil, fmt.Errorf("failed to load backend config: %w", err) + } + + if !cfgLLM.Validate() { + return nil, fmt.Errorf("failed to validate config: %w", err) + } + + cfgTTS, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.TTS, ml.ModelPath) + if err != nil { + + return nil, fmt.Errorf("failed to load backend config: %w", err) + } + + if !cfgTTS.Validate() { + return nil, fmt.Errorf("failed to validate config: %w", err) + } + + cfgSST, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.Transcription, ml.ModelPath) + if err != nil { + + return nil, fmt.Errorf("failed to load backend config: %w", err) + } + + if !cfgSST.Validate() { + return nil, fmt.Errorf("failed to validate config: %w", err) + } + + opts = backend.ModelOptions(*cfgTTS, appConfig) + ttsClient, err := ml.Load(opts...) + if err != nil { + return nil, fmt.Errorf("failed to load tts model: %w", err) + } + + opts = backend.ModelOptions(*cfgSST, appConfig) + transcriptionClient, err := ml.Load(opts...) + if err != nil { + return nil, fmt.Errorf("failed to load SST model: %w", err) + } + + opts = backend.ModelOptions(*cfgLLM, appConfig) + llmClient, err := ml.Load(opts...) 
+ if err != nil { + return nil, fmt.Errorf("failed to load LLM model: %w", err) + } + + return &wrappedModel{ + TTSConfig: cfgTTS, + TranscriptionConfig: cfgSST, + LLMConfig: cfgLLM, + TTSClient: ttsClient, + TranscriptionClient: transcriptionClient, + LLMClient: llmClient, + + VADConfig: cfgVAD, + VADClient: VADClient, + }, nil +} diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index fd17613a..cadd3d29 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -15,6 +15,9 @@ func RegisterOpenAIRoutes(app *fiber.App, application *application.Application) { // openAI compatible API endpoint + // realtime + app.Get("/v1/realtime", openai.Realtime(application)) + // chat chatChain := []fiber.Handler{ re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), diff --git a/go.mod b/go.mod index 5567c372..4d1079cb 100644 --- a/go.mod +++ b/go.mod @@ -9,10 +9,8 @@ require ( github.com/GeertJohan/go.rice v1.0.3 github.com/Masterminds/sprig/v3 v3.3.0 github.com/alecthomas/kong v0.9.0 - github.com/census-instrumentation/opencensus-proto v0.4.1 github.com/charmbracelet/glamour v0.7.0 github.com/chasefleming/elem-go v0.26.0 - github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20 github.com/containerd/containerd v1.7.19 github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 github.com/elliotchance/orderedmap/v2 v2.2.0 @@ -25,11 +23,9 @@ require ( github.com/gofiber/template/html/v2 v2.1.2 github.com/gofiber/websocket/v2 v2.2.1 github.com/gofrs/flock v0.12.1 - github.com/golang/protobuf v1.5.4 github.com/google/go-containerregistry v0.19.2 github.com/google/uuid v1.6.0 github.com/gpustack/gguf-parser-go v0.17.0 - github.com/grpc-ecosystem/grpc-gateway v1.5.0 github.com/hpcloud/tail v1.0.0 github.com/ipfs/go-log v1.0.5 github.com/jaypipes/ghw v0.12.0 @@ -43,7 +39,6 @@ require ( github.com/nikolalohinski/gonja/v2 v2.3.2 github.com/onsi/ginkgo/v2 v2.22.2 github.com/onsi/gomega v1.36.2 - github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e github.com/otiai10/openaigo v1.7.0 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/prometheus/client_golang v1.20.5 @@ -62,7 +57,6 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.50.0 go.opentelemetry.io/otel/metric v1.34.0 go.opentelemetry.io/otel/sdk/metric v1.28.0 - google.golang.org/api v0.180.0 google.golang.org/grpc v1.67.1 google.golang.org/protobuf v1.36.5 gopkg.in/yaml.v2 v2.4.0 @@ -71,22 +65,14 @@ require ( ) require ( - cel.dev/expr v0.16.0 // indirect - cloud.google.com/go/auth v0.4.1 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.5.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/daaku/go.zipexe v1.0.2 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect - github.com/fasthttp/websocket v1.5.3 // indirect + github.com/fasthttp/websocket v1.5.8 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect - github.com/google/s2a-go v0.1.7 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect - github.com/googleapis/gax-go/v2 v2.12.4 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/moby/docker-image-spec v1.3.1 
// indirect @@ -119,15 +105,13 @@ require ( github.com/pion/turn/v4 v4.0.0 // indirect github.com/pion/webrtc/v4 v4.0.9 // indirect github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 // indirect - github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect + github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511 // indirect github.com/shirou/gopsutil/v4 v4.24.7 // indirect github.com/wlynxg/anet v0.0.5 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect go.uber.org/mock v0.5.0 // indirect - golang.org/x/oauth2 v0.24.0 // indirect golang.org/x/time v0.8.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect ) require ( @@ -162,7 +146,7 @@ require ( github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/docker v27.1.1+incompatible github.com/docker/docker-credential-helpers v0.7.0 // indirect - github.com/docker/go-connections v0.5.0 // indirect + github.com/docker/go-connections v0.5.0 github.com/docker/go-units v0.5.0 // indirect github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect github.com/elastic/gosigar v0.14.3 // indirect diff --git a/go.sum b/go.sum index 6af7a14b..dae92fae 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,7 @@ -cel.dev/expr v0.16.0 h1:yloc84fytn4zmJX2GU3TkXGsaieaV7dQ057Qs4sIG2Y= -cel.dev/expr v0.16.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= -cloud.google.com/go/auth v0.4.1 h1:Z7YNIhlWRtrnKlZke7z3GMqzvuYzdc2z98F9D1NV5Hg= -cloud.google.com/go/auth v0.4.1/go.mod h1:QVBuVEKpCn4Zp58hzRGvL0tjRGU0YqdRTdCHM1IHnro= -cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= -cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= -cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= -cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= @@ -73,8 +65,6 @@ github.com/c-robinson/iplib v1.0.8/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szN github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= -github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charmbracelet/glamour v0.7.0 h1:2BtKGZ4iVJCDfMF229EzbeR1QRKLWztO9dMtjmqZSng= @@ -84,8 
+74,6 @@ github.com/chasefleming/elem-go v0.26.0/go.mod h1:hz73qILBIKnTgOujnSMtEj20/epI+f github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20 h1:N+3sFI5GUjRKBi+i0TxYVST9h4Ie192jJWpHvthBBgg= -github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE= github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM= github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw= @@ -161,10 +149,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/envoyproxy/protoc-gen-validate v1.1.0 h1:tntQDh69XqOCOZsDz0lVJQez/2L6Uu2PdjCQwWCJ3bM= -github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4= -github.com/fasthttp/websocket v1.5.3 h1:TPpQuLwJYfd4LJPXvHDYPMFWbLjsT91n3GpWtCQtdek= -github.com/fasthttp/websocket v1.5.3/go.mod h1:46gg/UBmTU1kUaTcwQXpUxtRwG2PvIZYeA8oL6vF3Fs= +github.com/fasthttp/websocket v1.5.8 h1:k5DpirKkftIF/w1R8ZzjSgARJrs54Je9YJK37DL/Ah8= +github.com/fasthttp/websocket v1.5.8/go.mod h1:d08g8WaT6nnyvg9uMm8K9zMYyDjfKyj3170AtPRuVU0= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= @@ -177,6 +163,8 @@ github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20240626202019-c118733a29ad h1:dQ93Vd6i25o+zH9vvnZ8mu7jtJQ6jT3D+zE3V8Q49n0= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20240626202019-c118733a29ad/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= @@ -250,8 +238,6 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf 
v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -280,18 +266,12 @@ github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OI github.com/google/pprof v0.0.0-20250208200701-d0013a598941 h1:43XjGa6toxLpeksjcxs1jIoIyr+vUfOqY2c6HB4bpoc= github.com/google/pprof v0.0.0-20250208200701-d0013a598941/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= -github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= -github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= -github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg= -github.com/googleapis/gax-go/v2 v2.12.4/go.mod h1:KYEYLorsnIGDi/rPC8b5TdlB9kbKoFubselGIoBMCwI= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c h1:7lF+Vz0LqiRidnzC1Oq86fpX1q/iEv2KJdrCtttYjT4= github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -505,6 +485,8 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/mudler/edgevpn v0.30.1 h1:4yyhNFJX62NpRp50sxiyZE5E/sdAqEZX+aE5Mv7QS60= github.com/mudler/edgevpn v0.30.1/go.mod h1:IAJkkJ0oH3rwsSGOGTFT4UBYFqYuD/QyaKzTLB3P/eU= +github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc h1:RxwneJl1VgvikiX28EkpdAyL4yQVnJMrbquKospjHyA= +github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82 h1:FVT07EI8njvsD4tC2Hw8Xhactp5AWhsQWD4oTeQuSAU= github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82/go.mod h1:Urp7LG5jylKoDq0663qeBh0pINGcRl35nXdKx82PSoU= github.com/mudler/water v0.0.0-20221010214108-8c7313014ce0 h1:Qh6ghkMgTu6siFbTf7L3IszJmshMhXxNL4V+t7IIA6w= @@ -564,8 +546,6 @@ github.com/opencontainers/runtime-spec v1.2.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/ github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= -github.com/orcaman/writerseeker 
v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw= -github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0= github.com/otiai10/mint v1.6.1 h1:kgbTJmOpp/0ce7hk3H8jiSuR0MXmpwWRfqUdKww17qg= github.com/otiai10/mint v1.6.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= github.com/otiai10/openaigo v1.7.0 h1:AOQcOjRRM57ABvz+aI2oJA/Qsz1AydKbdZAlGiKyCqg= @@ -681,8 +661,8 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sashabaranov/go-openai v1.26.2 h1:cVlQa3gn3eYqNXRW03pPlpy6zLG52EU4g0FrWXc0EFI= github.com/sashabaranov/go-openai v1.26.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= -github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= -github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= +github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511 h1:KanIMPX0QdEdB4R3CiimCAbxFrhB3j7h0/OvpYGVQa8= +github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= github.com/schollz/progressbar/v3 v3.14.4 h1:W9ZrDSJk7eqmQhd3uxFNNcTr0QL+xuGNI9dEMrw0r74= github.com/schollz/progressbar/v3 v3.14.4/go.mod h1:aT3UQ7yGm+2ZjeXPqsjTenwL3ddUiuZ0kfQ/2tHlyNI= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= @@ -925,8 +905,6 @@ golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAG golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= -golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1046,8 +1024,6 @@ gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= -google.golang.org/api v0.180.0 h1:M2D87Yo0rGBPWpo1orwfCLehUUL6E7/TYe5gvMQWDh4= -google.golang.org/api v0.180.0/go.mod h1:51AiyoEg1MJPSZ9zvklA8VnRILPXxn1iVen9v25XHAE= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1060,7 +1036,6 @@ google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRn google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= 
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda h1:wu/KJm9KJwpfHWhkkZGohVC6KRrc1oJNr4jwtQMOQXw=
-google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda/go.mod h1:g2LLCvCeCSir/JJSWosk19BR4NVxGqHUC6rxIRsd7Aw=
google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 h1:T6rh4haD3GVYsgEfWExoCZA2o2FmbNyKpTuAxbEFPTg=
google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:wp2WsuBYj6j8wUdo3ToZsdxxixbvQNAHqVJrTgi5E5M=
google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 h1:QCqS/PdaHTSWGvupk2F/ehwHtGc0/GYkT+3GAcR1CCc=
diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go
index 9f9f19b1..ac3fe757 100644
--- a/pkg/grpc/backend.go
+++ b/pkg/grpc/backend.go
@@ -35,9 +35,9 @@ type Backend interface {
 	IsBusy() bool
 	HealthCheck(ctx context.Context) (bool, error)
 	Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
-	Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
 	LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
 	PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
+	Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
 	GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
diff --git a/pkg/sound/float32.go b/pkg/sound/float32.go
new file mode 100644
index 00000000..f42a04e5
--- /dev/null
+++ b/pkg/sound/float32.go
@@ -0,0 +1,12 @@
+package sound
+
+import (
+	"encoding/binary"
+	"math"
+)
+
+func BytesFloat32(bytes []byte) float32 {
+	bits := binary.LittleEndian.Uint32(bytes)
+	float := math.Float32frombits(bits)
+	return float
+}
diff --git a/pkg/sound/int16.go b/pkg/sound/int16.go
new file mode 100644
index 00000000..237c805c
--- /dev/null
+++ b/pkg/sound/int16.go
@@ -0,0 +1,78 @@
+package sound
+
+import "math"
+
+/*
+
+MIT License
+
+Copyright (c) 2024 Xbozon
+
+*/
+
+// calculateRMS16 calculates the root mean square of the audio buffer for int16 samples.
+func CalculateRMS16(buffer []int16) float64 {
+	var sumSquares float64
+	for _, sample := range buffer {
+		val := float64(sample) // Convert int16 to float64 for calculation
+		sumSquares += val * val
+	}
+	meanSquares := sumSquares / float64(len(buffer))
+	return math.Sqrt(meanSquares)
+}
+
+func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
+	// Calculate the resampling ratio
+	ratio := float64(inputRate) / float64(outputRate)
+
+	// Calculate the length of the resampled output
+	outputLength := int(float64(len(input)) / ratio)
+
+	// Allocate a slice for the resampled output
+	output := make([]int16, outputLength)
+
+	// Perform linear interpolation for resampling
+	for i := 0; i < outputLength-1; i++ {
+		// Calculate the corresponding position in the input
+		pos := float64(i) * ratio
+
+		// Calculate the indices of the surrounding input samples
+		indexBefore := int(pos)
+		indexAfter := indexBefore + 1
+		if indexAfter >= len(input) {
+			indexAfter = len(input) - 1
+		}
+
+		// Calculate the fractional part of the position
+		frac := pos - float64(indexBefore)
+
+		// Linearly interpolate between the two surrounding input samples
+		output[i] = int16((1-frac)*float64(input[indexBefore]) + frac*float64(input[indexAfter]))
+	}
+
+	// Handle the last sample explicitly to avoid index out of range
+	output[outputLength-1] = input[len(input)-1]
+
+	return output
+}
+
+func ConvertInt16ToInt(input []int16) []int {
+	output := make([]int, len(input)) // Allocate a slice for the output
+	for i, value := range input {
+		output[i] = int(value) // Convert each int16 to int and assign it to the output slice
+	}
+	return output // Return the converted slice
+}
+
+func BytesToInt16sLE(bytes []byte) []int16 {
+	// Ensure the byte slice length is even
+	if len(bytes)%2 != 0 {
+		panic("bytesToInt16sLE: input bytes slice has odd length, must be even")
+	}
+
+	int16s := make([]int16, len(bytes)/2)
+	for i := 0; i < len(int16s); i++ {
+		int16s[i] = int16(bytes[2*i]) | int16(bytes[2*i+1])<<8
+	}
+	return int16s
+}
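For illustration only (not part of the diff): a minimal sketch of how the new pkg/sound helpers could be chained when handling a realtime audio chunk. The import path follows the repository layout; the 24kHz source rate, the 16kHz target rate, and the RMS threshold are assumptions chosen for the example, not values taken from the realtime implementation.

package main

import (
	"fmt"

	"github.com/mudler/LocalAI/pkg/sound"
)

func main() {
	// Two little-endian PCM16 samples: 256 (0x0100) and 32767 (0x7FFF).
	raw := []byte{0x00, 0x01, 0xFF, 0x7F}

	// Decode the raw byte stream into int16 samples.
	samples := sound.BytesToInt16sLE(raw)

	// Simple energy gate: skip near-silent chunks. The threshold is an assumed value to tune per deployment.
	const rmsThreshold = 200.0
	if sound.CalculateRMS16(samples) < rmsThreshold {
		fmt.Println("chunk below energy threshold, skipping")
		return
	}

	// Downsample from an assumed 24kHz source to the 16kHz the silero VAD backend expects.
	resampled := sound.ResampleInt16(samples, 24000, 16000)
	fmt.Printf("resampled %d -> %d samples\n", len(samples), len(resampled))
}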