mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
Small adaptations
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
06e438d68b
commit
c526f05de5
5 changed files with 26 additions and 9 deletions
|
@ -10,7 +10,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
|
@ -121,10 +123,14 @@ var sessionLock sync.Mutex
|
|||
type Model interface {
|
||||
VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
|
||||
Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
|
||||
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
|
||||
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
|
||||
}
|
||||
|
||||
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
|
||||
func Realtime(application *application.Application) fiber.Handler {
|
||||
return websocket.New(registerRealtime(application))
|
||||
}
|
||||
|
||||
func registerRealtime(application *application.Application) func(c *websocket.Conn) {
|
||||
return func(c *websocket.Conn) {
|
||||
|
||||
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
|
||||
|
@ -153,7 +159,12 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
|||
session.Conversations[conversationID] = conversation
|
||||
session.DefaultConversationID = conversationID
|
||||
|
||||
m, err := newModel(cl, ml, appConfig, model)
|
||||
m, err := newModel(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.ApplicationConfig(),
|
||||
model,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Msgf("failed to load model: %s", err.Error())
|
||||
sendError(c, "model_load_error", "Failed to load model", "", "")
|
||||
|
@ -210,7 +221,13 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
|||
sendError(c, "invalid_session_update", "Invalid session update format", "", "")
|
||||
continue
|
||||
}
|
||||
if err := updateSession(session, &sessionUpdate, cl, ml, appConfig); err != nil {
|
||||
if err := updateSession(
|
||||
session,
|
||||
&sessionUpdate,
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.ApplicationConfig(),
|
||||
); err != nil {
|
||||
log.Error().Msgf("failed to update session: %s", err.Error())
|
||||
sendError(c, "session_update_error", "Failed to update session", "", "")
|
||||
continue
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue