feat(realtime): Initial Realtime API implementation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-10-02 11:16:11 +02:00 committed by Richard Palethorpe
parent 456b4982ef
commit ae6069a0a0
13 changed files with 1453 additions and 1135 deletions

View file

@ -162,6 +162,7 @@ message Reply {
int32 prompt_tokens = 3;
double timing_prompt_processing = 4;
double timing_token_generation = 5;
bytes audio = 6;
}
message GrammarTrigger {

View file

@ -21,8 +21,8 @@ func (vad *VAD) Load(opts *pb.ModelOptions) error {
SampleRate: 16000,
//WindowSize: 1024,
Threshold: 0.5,
MinSilenceDurationMs: 0,
SpeechPadMs: 0,
MinSilenceDurationMs: 100,
SpeechPadMs: 30,
})
if err != nil {
return fmt.Errorf("create silero detector: %w", err)

View file

@ -24,6 +24,7 @@ import (
type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
AudioOutput string
}
type TokenUsage struct {

View file

@ -37,6 +37,7 @@ type BackendConfig struct {
TemplateConfig TemplateConfig `yaml:"template"`
KnownUsecaseStrings []string `yaml:"known_usecases"`
KnownUsecases *BackendConfigUsecases `yaml:"-"`
Pipeline Pipeline `yaml:"pipeline"`
PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
@ -72,6 +73,18 @@ type BackendConfig struct {
Options []string `yaml:"options"`
}
// Pipeline defines other models to use for audio-to-audio
type Pipeline struct {
TTS string `yaml:"tts"`
LLM string `yaml:"llm"`
Transcription string `yaml:"transcription"`
VAD string `yaml:"vad"`
}
func (p Pipeline) IsNotConfigured() bool {
return p.LLM == "" || p.TTS == "" || p.Transcription == ""
}
type File struct {
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`

View file

@ -9,6 +9,7 @@ import (
"path/filepath"
"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
@ -91,6 +92,7 @@ func API(application *application.Application) (*fiber.App, error) {
router.Use(middleware.StripPathPrefix())
<<<<<<< HEAD
if application.ApplicationConfig().MachineTag != "" {
router.Use(func(c *fiber.Ctx) error {
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
@ -98,6 +100,16 @@ func API(application *application.Application) (*fiber.App, error) {
return c.Next()
})
}
=======
router.Use("/v1/realtime", func(c *fiber.Ctx) error {
if websocket.IsWebSocketUpgrade(c) {
// Returns true if the client requested upgrade to the WebSocket protocol
return c.Next()
}
return nil
})
>>>>>>> 43463868 (feat(realtime): Initial Realtime API implementation)
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,186 @@
package openai
import (
"context"
"fmt"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
grpcClient "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"google.golang.org/grpc"
)
var (
_ Model = new(wrappedModel)
_ Model = new(anyToAnyModel)
)
// wrappedModel represent a model which does not support Any-to-Any operations
// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
// which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS)
type wrappedModel struct {
TTSConfig *config.BackendConfig
TranscriptionConfig *config.BackendConfig
LLMConfig *config.BackendConfig
TTSClient grpcClient.Backend
TranscriptionClient grpcClient.Backend
LLMClient grpcClient.Backend
VADConfig *config.BackendConfig
VADClient grpcClient.Backend
}
// anyToAnyModel represent a model which supports Any-to-Any operations
// We have to wrap this out as well because we want to load two models one for VAD and one for the actual model.
// In the future there could be models that accept continous audio input only so this design will be useful for that
type anyToAnyModel struct {
LLMConfig *config.BackendConfig
LLMClient grpcClient.Backend
VADConfig *config.BackendConfig
VADClient grpcClient.Backend
}
func (m *wrappedModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
return m.VADClient.VAD(ctx, in)
}
func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
return m.VADClient.VAD(ctx, in)
}
func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
// sound.BufferAsWAV(audioData, "audio.wav")
return m.LLMClient.Predict(ctx, in)
}
func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
return m.LLMClient.PredictStream(ctx, in, f)
}
func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
return m.LLMClient.Predict(ctx, in)
}
func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
return m.LLMClient.PredictStream(ctx, in, f)
}
// returns and loads either a wrapped model or a model that support audio-to-audio
func newModel(cfg *config.BackendConfig, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {
// Prepare VAD model
cfgVAD, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.VAD, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
}
if !cfgVAD.Validate() {
return nil, fmt.Errorf("failed to validate config: %w", err)
}
opts := backend.ModelOptions(*cfgVAD, appConfig)
VADClient, err := ml.Load(opts...)
if err != nil {
return nil, fmt.Errorf("failed to load tts model: %w", err)
}
// If we don't have Wrapped model definitions, just return a standard model
if cfg.Pipeline.IsNotConfigured() {
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
cfgAnyToAny, err := cl.LoadBackendConfigFileByName(cfg.Model, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
}
if !cfgAnyToAny.Validate() {
return nil, fmt.Errorf("failed to validate config: %w", err)
}
opts := backend.ModelOptions(*cfgAnyToAny, appConfig)
anyToAnyClient, err := ml.Load(opts...)
if err != nil {
return nil, fmt.Errorf("failed to load tts model: %w", err)
}
return &anyToAnyModel{
LLMConfig: cfgAnyToAny,
LLMClient: anyToAnyClient,
VADConfig: cfgVAD,
VADClient: VADClient,
}, nil
}
log.Debug().Msg("Loading a wrapped model")
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
cfgLLM, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.LLM, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
}
if !cfgLLM.Validate() {
return nil, fmt.Errorf("failed to validate config: %w", err)
}
cfgTTS, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.TTS, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
}
if !cfgTTS.Validate() {
return nil, fmt.Errorf("failed to validate config: %w", err)
}
cfgSST, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.Transcription, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
}
if !cfgSST.Validate() {
return nil, fmt.Errorf("failed to validate config: %w", err)
}
opts = backend.ModelOptions(*cfgTTS, appConfig)
ttsClient, err := ml.Load(opts...)
if err != nil {
return nil, fmt.Errorf("failed to load tts model: %w", err)
}
opts = backend.ModelOptions(*cfgSST, appConfig)
transcriptionClient, err := ml.Load(opts...)
if err != nil {
return nil, fmt.Errorf("failed to load SST model: %w", err)
}
opts = backend.ModelOptions(*cfgLLM, appConfig)
llmClient, err := ml.Load(opts...)
if err != nil {
return nil, fmt.Errorf("failed to load LLM model: %w", err)
}
return &wrappedModel{
TTSConfig: cfgTTS,
TranscriptionConfig: cfgSST,
LLMConfig: cfgLLM,
TTSClient: ttsClient,
TranscriptionClient: transcriptionClient,
LLMClient: llmClient,
VADConfig: cfgVAD,
VADClient: VADClient,
}, nil
}

View file

@ -15,6 +15,9 @@ func RegisterOpenAIRoutes(app *fiber.App,
application *application.Application) {
// openAI compatible API endpoint
// realtime
app.Get("/v1/realtime", openai.Realtime(application))
// chat
chatChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),

7
go.mod
View file

@ -40,6 +40,7 @@ require (
github.com/microcosm-cc/bluemonday v1.0.26
github.com/mudler/edgevpn v0.30.1
github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82
github.com/mudler/go-stable-diffusion v0.0.0-20240429204715-4a3cd6aeae6f
github.com/nikolalohinski/gonja/v2 v2.3.2
github.com/onsi/ginkgo/v2 v2.22.2
github.com/onsi/gomega v1.36.2
@ -81,7 +82,7 @@ require (
github.com/distribution/reference v0.6.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect
github.com/fasthttp/websocket v1.5.3 // indirect
github.com/fasthttp/websocket v1.5.8 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/s2a-go v0.1.7 // indirect
@ -123,6 +124,10 @@ require (
github.com/shirou/gopsutil/v4 v4.24.7 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
github.com/pion/webrtc/v3 v3.3.5 // indirect
github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511 // indirect
github.com/shirou/gopsutil/v4 v4.24.7 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
go.uber.org/mock v0.5.0 // indirect
golang.org/x/oauth2 v0.24.0 // indirect

1129
go.sum

File diff suppressed because it is too large Load diff

View file

@ -35,9 +35,9 @@ type Backend interface {
IsBusy() bool
HealthCheck(ctx context.Context) (bool, error)
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)

12
pkg/sound/float32.go Normal file
View file

@ -0,0 +1,12 @@
package sound
import (
"encoding/binary"
"math"
)
func BytesFloat32(bytes []byte) float32 {
bits := binary.LittleEndian.Uint32(bytes)
float := math.Float32frombits(bits)
return float
}

78
pkg/sound/int16.go Normal file
View file

@ -0,0 +1,78 @@
package sound
import "math"
/*
MIT License
Copyright (c) 2024 Xbozon
*/
// calculateRMS16 calculates the root mean square of the audio buffer for int16 samples.
func CalculateRMS16(buffer []int16) float64 {
var sumSquares float64
for _, sample := range buffer {
val := float64(sample) // Convert int16 to float64 for calculation
sumSquares += val * val
}
meanSquares := sumSquares / float64(len(buffer))
return math.Sqrt(meanSquares)
}
func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
// Calculate the resampling ratio
ratio := float64(inputRate) / float64(outputRate)
// Calculate the length of the resampled output
outputLength := int(float64(len(input)) / ratio)
// Allocate a slice for the resampled output
output := make([]int16, outputLength)
// Perform linear interpolation for resampling
for i := 0; i < outputLength-1; i++ {
// Calculate the corresponding position in the input
pos := float64(i) * ratio
// Calculate the indices of the surrounding input samples
indexBefore := int(pos)
indexAfter := indexBefore + 1
if indexAfter >= len(input) {
indexAfter = len(input) - 1
}
// Calculate the fractional part of the position
frac := pos - float64(indexBefore)
// Linearly interpolate between the two surrounding input samples
output[i] = int16((1-frac)*float64(input[indexBefore]) + frac*float64(input[indexAfter]))
}
// Handle the last sample explicitly to avoid index out of range
output[outputLength-1] = input[len(input)-1]
return output
}
func ConvertInt16ToInt(input []int16) []int {
output := make([]int, len(input)) // Allocate a slice for the output
for i, value := range input {
output[i] = int(value) // Convert each int16 to int and assign it to the output slice
}
return output // Return the converted slice
}
func BytesToInt16sLE(bytes []byte) []int16 {
// Ensure the byte slice length is even
if len(bytes)%2 != 0 {
panic("bytesToInt16sLE: input bytes slice has odd length, must be even")
}
int16s := make([]int16, len(bytes)/2)
for i := 0; i < len(int16s); i++ {
int16s[i] = int16(bytes[2*i]) | int16(bytes[2*i+1])<<8
}
return int16s
}