mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
feat(realtime): Initial Realtime API implementation
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
456b4982ef
commit
ae6069a0a0
13 changed files with 1453 additions and 1135 deletions
|
@ -162,6 +162,7 @@ message Reply {
|
|||
int32 prompt_tokens = 3;
|
||||
double timing_prompt_processing = 4;
|
||||
double timing_token_generation = 5;
|
||||
bytes audio = 6;
|
||||
}
|
||||
|
||||
message GrammarTrigger {
|
||||
|
|
|
@ -21,8 +21,8 @@ func (vad *VAD) Load(opts *pb.ModelOptions) error {
|
|||
SampleRate: 16000,
|
||||
//WindowSize: 1024,
|
||||
Threshold: 0.5,
|
||||
MinSilenceDurationMs: 0,
|
||||
SpeechPadMs: 0,
|
||||
MinSilenceDurationMs: 100,
|
||||
SpeechPadMs: 30,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create silero detector: %w", err)
|
||||
|
|
|
@ -22,8 +22,9 @@ import (
|
|||
)
|
||||
|
||||
type LLMResponse struct {
|
||||
Response string // should this be []byte?
|
||||
Usage TokenUsage
|
||||
Response string // should this be []byte?
|
||||
Usage TokenUsage
|
||||
AudioOutput string
|
||||
}
|
||||
|
||||
type TokenUsage struct {
|
||||
|
|
|
@ -37,6 +37,7 @@ type BackendConfig struct {
|
|||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
KnownUsecaseStrings []string `yaml:"known_usecases"`
|
||||
KnownUsecases *BackendConfigUsecases `yaml:"-"`
|
||||
Pipeline Pipeline `yaml:"pipeline"`
|
||||
|
||||
PromptStrings, InputStrings []string `yaml:"-"`
|
||||
InputToken [][]int `yaml:"-"`
|
||||
|
@ -72,6 +73,18 @@ type BackendConfig struct {
|
|||
Options []string `yaml:"options"`
|
||||
}
|
||||
|
||||
// Pipeline defines other models to use for audio-to-audio
type Pipeline struct {
	TTS           string `yaml:"tts"`
	LLM           string `yaml:"llm"`
	Transcription string `yaml:"transcription"`
	VAD           string `yaml:"vad"`
}

// IsNotConfigured reports whether any of the stages required to emulate
// audio-to-audio (LLM, TTS, transcription) is missing from the pipeline.
// Note that VAD is intentionally not part of this check.
func (p Pipeline) IsNotConfigured() bool {
	for _, stage := range []string{p.LLM, p.TTS, p.Transcription} {
		if stage == "" {
			return true
		}
	}
	return false
}
|
||||
|
||||
type File struct {
|
||||
Filename string `yaml:"filename" json:"filename"`
|
||||
SHA256 string `yaml:"sha256" json:"sha256"`
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"path/filepath"
|
||||
|
||||
"github.com/dave-gray101/v2keyauth"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
|
@ -91,6 +92,7 @@ func API(application *application.Application) (*fiber.App, error) {
|
|||
|
||||
router.Use(middleware.StripPathPrefix())
|
||||
|
||||
<<<<<<< HEAD
|
||||
if application.ApplicationConfig().MachineTag != "" {
|
||||
router.Use(func(c *fiber.Ctx) error {
|
||||
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
|
||||
|
@ -98,6 +100,16 @@ func API(application *application.Application) (*fiber.App, error) {
|
|||
return c.Next()
|
||||
})
|
||||
}
|
||||
=======
|
||||
router.Use("/v1/realtime", func(c *fiber.Ctx) error {
|
||||
if websocket.IsWebSocketUpgrade(c) {
|
||||
// Returns true if the client requested upgrade to the WebSocket protocol
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
>>>>>>> 43463868 (feat(realtime): Initial Realtime API implementation)
|
||||
|
||||
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
|
||||
scheme := "http"
|
||||
|
|
1136
core/http/endpoints/openai/realtime.go
Normal file
1136
core/http/endpoints/openai/realtime.go
Normal file
File diff suppressed because it is too large
Load diff
186
core/http/endpoints/openai/realtime_model.go
Normal file
186
core/http/endpoints/openai/realtime_model.go
Normal file
|
@ -0,0 +1,186 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
grpcClient "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Compile-time assertions that both implementations satisfy the Model interface.
var (
	_ Model = new(wrappedModel)
	_ Model = new(anyToAnyModel)
)
|
||||
|
||||
// wrappedModel represent a model which does not support Any-to-Any operations
// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
// which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS)
type wrappedModel struct {
	// Configs for each stage of the emulated pipeline.
	TTSConfig           *config.BackendConfig
	TranscriptionConfig *config.BackendConfig
	LLMConfig           *config.BackendConfig
	// Loaded gRPC backend clients corresponding to the configs above.
	TTSClient           grpcClient.Backend
	TranscriptionClient grpcClient.Backend
	LLMClient           grpcClient.Backend

	// Voice-activity-detection model, loaded separately from the pipeline stages.
	VADConfig *config.BackendConfig
	VADClient grpcClient.Backend
}
|
||||
|
||||
// anyToAnyModel represent a model which supports Any-to-Any operations
// We have to wrap this out as well because we want to load two models one for VAD and one for the actual model.
// In the future there could be models that accept continuous audio input only so this design will be useful for that
type anyToAnyModel struct {
	// The audio-to-audio capable model itself.
	LLMConfig *config.BackendConfig
	LLMClient grpcClient.Backend

	// Voice-activity-detection model, loaded alongside the main model.
	VADConfig *config.BackendConfig
	VADClient grpcClient.Backend
}
|
||||
|
||||
func (m *wrappedModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
|
||||
return m.VADClient.VAD(ctx, in)
|
||||
}
|
||||
|
||||
func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
|
||||
return m.VADClient.VAD(ctx, in)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
|
||||
// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
|
||||
// sound.BufferAsWAV(audioData, "audio.wav")
|
||||
|
||||
return m.LLMClient.Predict(ctx, in)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
|
||||
// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
|
||||
|
||||
return m.LLMClient.PredictStream(ctx, in, f)
|
||||
}
|
||||
|
||||
func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
|
||||
return m.LLMClient.Predict(ctx, in)
|
||||
}
|
||||
|
||||
func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
|
||||
return m.LLMClient.PredictStream(ctx, in, f)
|
||||
}
|
||||
|
||||
// returns and loads either a wrapped model or a model that support audio-to-audio
|
||||
func newModel(cfg *config.BackendConfig, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {
|
||||
|
||||
// Prepare VAD model
|
||||
cfgVAD, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
}
|
||||
|
||||
if !cfgVAD.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
opts := backend.ModelOptions(*cfgVAD, appConfig)
|
||||
VADClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||
}
|
||||
|
||||
// If we don't have Wrapped model definitions, just return a standard model
|
||||
if cfg.Pipeline.IsNotConfigured() {
|
||||
|
||||
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
||||
cfgAnyToAny, err := cl.LoadBackendConfigFileByName(cfg.Model, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
}
|
||||
|
||||
if !cfgAnyToAny.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
opts := backend.ModelOptions(*cfgAnyToAny, appConfig)
|
||||
anyToAnyClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||
}
|
||||
|
||||
return &anyToAnyModel{
|
||||
LLMConfig: cfgAnyToAny,
|
||||
LLMClient: anyToAnyClient,
|
||||
VADConfig: cfgVAD,
|
||||
VADClient: VADClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug().Msg("Loading a wrapped model")
|
||||
|
||||
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
||||
cfgLLM, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.LLM, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
}
|
||||
|
||||
if !cfgLLM.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
cfgTTS, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.TTS, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
}
|
||||
|
||||
if !cfgTTS.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
cfgSST, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
}
|
||||
|
||||
if !cfgSST.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
opts = backend.ModelOptions(*cfgTTS, appConfig)
|
||||
ttsClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||
}
|
||||
|
||||
opts = backend.ModelOptions(*cfgSST, appConfig)
|
||||
transcriptionClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load SST model: %w", err)
|
||||
}
|
||||
|
||||
opts = backend.ModelOptions(*cfgLLM, appConfig)
|
||||
llmClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load LLM model: %w", err)
|
||||
}
|
||||
|
||||
return &wrappedModel{
|
||||
TTSConfig: cfgTTS,
|
||||
TranscriptionConfig: cfgSST,
|
||||
LLMConfig: cfgLLM,
|
||||
TTSClient: ttsClient,
|
||||
TranscriptionClient: transcriptionClient,
|
||||
LLMClient: llmClient,
|
||||
|
||||
VADConfig: cfgVAD,
|
||||
VADClient: VADClient,
|
||||
}, nil
|
||||
}
|
|
@ -15,6 +15,9 @@ func RegisterOpenAIRoutes(app *fiber.App,
|
|||
application *application.Application) {
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// realtime
|
||||
app.Get("/v1/realtime", openai.Realtime(application))
|
||||
|
||||
// chat
|
||||
chatChain := []fiber.Handler{
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
|
|
7
go.mod
7
go.mod
|
@ -40,6 +40,7 @@ require (
|
|||
github.com/microcosm-cc/bluemonday v1.0.26
|
||||
github.com/mudler/edgevpn v0.30.1
|
||||
github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82
|
||||
github.com/mudler/go-stable-diffusion v0.0.0-20240429204715-4a3cd6aeae6f
|
||||
github.com/nikolalohinski/gonja/v2 v2.3.2
|
||||
github.com/onsi/ginkgo/v2 v2.22.2
|
||||
github.com/onsi/gomega v1.36.2
|
||||
|
@ -81,7 +82,7 @@ require (
|
|||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect
|
||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||
github.com/fasthttp/websocket v1.5.8 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
|
@ -123,6 +124,10 @@ require (
|
|||
github.com/shirou/gopsutil/v4 v4.24.7 // indirect
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
github.com/pion/webrtc/v3 v3.3.5 // indirect
|
||||
github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.24.7 // indirect
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
golang.org/x/oauth2 v0.24.0 // indirect
|
||||
|
|
|
@ -35,9 +35,9 @@ type Backend interface {
|
|||
IsBusy() bool
|
||||
HealthCheck(ctx context.Context) (bool, error)
|
||||
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
|
||||
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
|
||||
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
|
||||
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
|
||||
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
|
|
12
pkg/sound/float32.go
Normal file
12
pkg/sound/float32.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package sound
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math"
|
||||
)
|
||||
|
||||
// BytesFloat32 decodes the first four bytes of the slice as a
// little-endian IEEE-754 float32.
func BytesFloat32(bytes []byte) float32 {
	return math.Float32frombits(binary.LittleEndian.Uint32(bytes))
}
|
78
pkg/sound/int16.go
Normal file
78
pkg/sound/int16.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package sound
|
||||
|
||||
import "math"
|
||||
|
||||
/*
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Xbozon
|
||||
|
||||
*/
|
||||
|
||||
// calculateRMS16 calculates the root mean square of the audio buffer for int16 samples.
|
||||
func CalculateRMS16(buffer []int16) float64 {
|
||||
var sumSquares float64
|
||||
for _, sample := range buffer {
|
||||
val := float64(sample) // Convert int16 to float64 for calculation
|
||||
sumSquares += val * val
|
||||
}
|
||||
meanSquares := sumSquares / float64(len(buffer))
|
||||
return math.Sqrt(meanSquares)
|
||||
}
|
||||
|
||||
// ResampleInt16 converts PCM samples from inputRate to outputRate using
// linear interpolation between neighbouring input samples.
// Empty input (or a ratio that produces no output samples) returns an empty
// slice; the previous version indexed output[outputLength-1] and
// input[len(input)-1] unconditionally and panicked on empty input.
func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
	if len(input) == 0 {
		return []int16{}
	}

	// Calculate the resampling ratio and the resulting output length.
	ratio := float64(inputRate) / float64(outputRate)
	outputLength := int(float64(len(input)) / ratio)
	if outputLength <= 0 {
		return []int16{}
	}

	output := make([]int16, outputLength)

	// Perform linear interpolation for resampling.
	for i := 0; i < outputLength-1; i++ {
		// Corresponding (fractional) position in the input.
		pos := float64(i) * ratio

		// Indices of the surrounding input samples.
		indexBefore := int(pos)
		indexAfter := indexBefore + 1
		if indexAfter >= len(input) {
			indexAfter = len(input) - 1
		}

		// Interpolate between the two surrounding samples.
		frac := pos - float64(indexBefore)
		output[i] = int16((1-frac)*float64(input[indexBefore]) + frac*float64(input[indexAfter]))
	}

	// Handle the last sample explicitly to avoid index out of range.
	output[outputLength-1] = input[len(input)-1]

	return output
}
|
||||
|
||||
// ConvertInt16ToInt widens each int16 sample to the platform int type,
// returning a new slice of the same length.
func ConvertInt16ToInt(input []int16) []int {
	result := make([]int, len(input))
	for idx := range input {
		result[idx] = int(input[idx])
	}
	return result
}
|
||||
|
||||
// BytesToInt16sLE reinterprets a little-endian byte stream as int16 samples
// (two bytes per sample). It panics when the byte count is odd.
func BytesToInt16sLE(bytes []byte) []int16 {
	if len(bytes)%2 != 0 {
		panic("bytesToInt16sLE: input bytes slice has odd length, must be even")
	}

	samples := make([]int16, 0, len(bytes)/2)
	for i := 0; i+1 < len(bytes); i += 2 {
		lo := int16(bytes[i])
		hi := int16(bytes[i+1]) << 8 // overflow wraps, preserving the sign bit
		samples = append(samples, lo|hi)
	}
	return samples
}
|
Loading…
Add table
Add a link
Reference in a new issue