Small adaptations

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-12-27 18:39:56 +01:00
parent 06e438d68b
commit c526f05de5
5 changed files with 26 additions and 9 deletions

View file

@ -10,7 +10,9 @@ import (
"time"
"github.com/go-audio/audio"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
@ -121,10 +123,14 @@ var sessionLock sync.Mutex
type Model interface {
VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
}
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
func Realtime(application *application.Application) fiber.Handler {
return websocket.New(registerRealtime(application))
}
func registerRealtime(application *application.Application) func(c *websocket.Conn) {
return func(c *websocket.Conn) {
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
@ -153,7 +159,12 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
session.Conversations[conversationID] = conversation
session.DefaultConversationID = conversationID
m, err := newModel(cl, ml, appConfig, model)
m, err := newModel(
application.BackendLoader(),
application.ModelLoader(),
application.ApplicationConfig(),
model,
)
if err != nil {
log.Error().Msgf("failed to load model: %s", err.Error())
sendError(c, "model_load_error", "Failed to load model", "", "")
@ -210,7 +221,13 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
sendError(c, "invalid_session_update", "Invalid session update format", "", "")
continue
}
if err := updateSession(session, &sessionUpdate, cl, ml, appConfig); err != nil {
if err := updateSession(
session,
&sessionUpdate,
application.BackendLoader(),
application.ModelLoader(),
application.ApplicationConfig(),
); err != nil {
log.Error().Msgf("failed to update session: %s", err.Error())
sendError(c, "session_update_error", "Failed to update session", "", "")
continue