mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-27 22:15:00 +00:00
chore: extract realtime models into two categories
One is anyToAny models that requires a VAD model, and one is wrappedModel that requires as well VAD models along others in the pipeline. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
4f69170273
commit
1796a1713d
3 changed files with 178 additions and 91 deletions
|
@ -1,6 +1,7 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
@ -8,10 +9,10 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
@ -114,95 +115,7 @@ var sessionLock sync.Mutex
|
|||
|
||||
// TODO: implement interface as we start to define usages
|
||||
type Model interface {
|
||||
}
|
||||
|
||||
type wrappedModel struct {
|
||||
TTSConfig *config.BackendConfig
|
||||
TranscriptionConfig *config.BackendConfig
|
||||
LLMConfig *config.BackendConfig
|
||||
TTSClient grpc.Backend
|
||||
TranscriptionClient grpc.Backend
|
||||
LLMClient grpc.Backend
|
||||
}
|
||||
|
||||
// returns and loads either a wrapped model or a model that support audio-to-audio
|
||||
func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByName(modelName, ml.ModelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
}
|
||||
|
||||
if !cfg.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Pipeline.LLM == "" || cfg.Pipeline.TTS == "" || cfg.Pipeline.Transcription == "" {
|
||||
// If we don't have Wrapped model definitions, just return a standard model
|
||||
opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend),
|
||||
model.WithModel(cfg.Model))
|
||||
return ml.Load(opts...)
|
||||
}
|
||||
|
||||
log.Debug().Msg("Loading a wrapped model")
|
||||
|
||||
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
||||
cfgLLM, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.LLM, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
}
|
||||
|
||||
if !cfg.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
cfgTTS, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.TTS, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
}
|
||||
|
||||
if !cfg.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
cfgSST, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
}
|
||||
|
||||
if !cfg.Validate() {
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
opts := backend.ModelOptions(*cfgTTS, appConfig)
|
||||
ttsClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||
}
|
||||
|
||||
opts = backend.ModelOptions(*cfgSST, appConfig)
|
||||
transcriptionClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load SST model: %w", err)
|
||||
}
|
||||
|
||||
opts = backend.ModelOptions(*cfgLLM, appConfig)
|
||||
llmClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load LLM model: %w", err)
|
||||
}
|
||||
|
||||
return &wrappedModel{
|
||||
TTSConfig: cfgTTS,
|
||||
TranscriptionConfig: cfgSST,
|
||||
LLMConfig: cfgLLM,
|
||||
TTSClient: ttsClient,
|
||||
TranscriptionClient: transcriptionClient,
|
||||
LLMClient: llmClient,
|
||||
}, nil
|
||||
VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
|
||||
}
|
||||
|
||||
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue