diff --git a/core/application/application.go b/core/application/application.go index 6e8d6204..8c9842d9 100644 --- a/core/application/application.go +++ b/core/application/application.go @@ -16,7 +16,7 @@ type Application struct { func newApplication(appConfig *config.ApplicationConfig) *Application { return &Application{ backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath), - modelLoader: model.NewModelLoader(appConfig.ModelPath), + modelLoader: model.NewModelLoader(appConfig.ModelPath, appConfig.SingleBackend), applicationConfig: appConfig, templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath), } diff --git a/core/application/startup.go b/core/application/startup.go index 3cfbd684..6c93f03f 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -143,7 +143,7 @@ func New(opts ...config.AppOption) (*Application, error) { }() } - if options.LoadToMemory != nil { + if options.LoadToMemory != nil && !options.SingleBackend { for _, m := range options.LoadToMemory { cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options) if err != nil { diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index a96e9829..aece0cdd 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -17,6 +17,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo if err != nil { return nil, err } + defer loader.Close() var fn func() ([]float32, error) switch model := inferenceModel.(type) { diff --git a/core/backend/image.go b/core/backend/image.go index 38ca4357..4b34f2cf 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -16,6 +16,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat if err != nil { return nil, err } + defer loader.Close() fn := func() error { _, err := inferenceModel.GenerateImage( diff --git a/core/backend/llm.go b/core/backend/llm.go index 14eb8569..57e2ae35 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -53,6 +53,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im if err != nil { return nil, err } + defer loader.Close() var protoMessages []*proto.Message // if we are using the tokenizer template, we need to convert the messages to proto messages diff --git a/core/backend/options.go b/core/backend/options.go index d98e136c..7a7a69bb 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -40,10 +40,6 @@ func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ... grpcOpts := grpcModelOpts(c) defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts)) - if so.SingleBackend { - defOpts = append(defOpts, model.WithSingleActiveBackend()) - } - if so.ParallelBackendRequests { defOpts = append(defOpts, model.EnableParallelRequests) } @@ -121,7 +117,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions { triggers := make([]*pb.GrammarTrigger, 0) for _, t := range c.FunctionsConfig.GrammarConfig.GrammarTriggers { triggers = append(triggers, &pb.GrammarTrigger{ - Word: t.Word, + Word: t.Word, }) } @@ -161,33 +157,33 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions { DisableLogStatus: c.DisableLogStatus, DType: c.DType, // LimitMMPerPrompt vLLM - LimitImagePerPrompt: int32(c.LimitMMPerPrompt.LimitImagePerPrompt), - LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt), - LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt), - MMProj: c.MMProj, - FlashAttention: c.FlashAttention, - CacheTypeKey: c.CacheTypeK, - CacheTypeValue: c.CacheTypeV, - NoKVOffload: c.NoKVOffloading, - YarnExtFactor: c.YarnExtFactor, - YarnAttnFactor: c.YarnAttnFactor, - YarnBetaFast: c.YarnBetaFast, - YarnBetaSlow: c.YarnBetaSlow, - NGQA: c.NGQA, - RMSNormEps: c.RMSNormEps, - MLock: mmlock, - RopeFreqBase: c.RopeFreqBase, - RopeScaling: c.RopeScaling, - Type: c.ModelType, - RopeFreqScale: c.RopeFreqScale, - NUMA: c.NUMA, - Embeddings: embeddings, - LowVRAM: lowVRAM, - NGPULayers: int32(nGPULayers), - MMap: mmap, - MainGPU: c.MainGPU, - Threads: int32(*c.Threads), - TensorSplit: c.TensorSplit, + LimitImagePerPrompt: int32(c.LimitMMPerPrompt.LimitImagePerPrompt), + LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt), + LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt), + MMProj: c.MMProj, + FlashAttention: c.FlashAttention, + CacheTypeKey: c.CacheTypeK, + CacheTypeValue: c.CacheTypeV, + NoKVOffload: c.NoKVOffloading, + YarnExtFactor: c.YarnExtFactor, + YarnAttnFactor: c.YarnAttnFactor, + YarnBetaFast: c.YarnBetaFast, + YarnBetaSlow: c.YarnBetaSlow, + NGQA: c.NGQA, + RMSNormEps: c.RMSNormEps, + MLock: mmlock, + RopeFreqBase: c.RopeFreqBase, + RopeScaling: c.RopeScaling, + Type: c.ModelType, + RopeFreqScale: c.RopeFreqScale, + NUMA: c.NUMA, + Embeddings: embeddings, + LowVRAM: lowVRAM, + NGPULayers: int32(nGPULayers), + MMap: mmap, + MainGPU: c.MainGPU, + Threads: int32(*c.Threads), + TensorSplit: c.TensorSplit, // AutoGPTQ ModelBaseName: c.AutoGPTQ.ModelBaseName, Device: c.AutoGPTQ.Device, diff --git a/core/backend/rerank.go b/core/backend/rerank.go index da565620..d7937ce4 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -12,10 +12,10 @@ import ( func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { opts := ModelOptions(backendConfig, appConfig) rerankModel, err := loader.Load(opts...) - if err != nil { return nil, err } + defer loader.Close() if rerankModel == nil { return nil, fmt.Errorf("could not load rerank model") diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index 49813d82..94ec9c89 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -26,10 +26,10 @@ func SoundGeneration( opts := ModelOptions(backendConfig, appConfig) soundGenModel, err := loader.Load(opts...) - if err != nil { return "", nil, err } + defer loader.Close() if soundGenModel == nil { return "", nil, fmt.Errorf("could not load sound generation model") diff --git a/core/backend/token_metrics.go b/core/backend/token_metrics.go index cc71c868..ac34e34f 100644 --- a/core/backend/token_metrics.go +++ b/core/backend/token_metrics.go @@ -20,6 +20,7 @@ func TokenMetrics( if err != nil { return nil, err } + defer loader.Close() if model == nil { return nil, fmt.Errorf("could not loadmodel model") diff --git a/core/backend/tokenize.go b/core/backend/tokenize.go index e04a59d8..43c46134 100644 --- a/core/backend/tokenize.go +++ b/core/backend/tokenize.go @@ -14,10 +14,10 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac opts := ModelOptions(backendConfig, appConfig) inferenceModel, err = loader.Load(opts...) - if err != nil { return schema.TokenizeResponse{}, err } + defer loader.Close() predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath) predictOptions.Prompt = s diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 080f43b1..64f9c5e2 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -24,6 +24,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL if err != nil { return nil, err } + defer ml.Close() if transcriptionModel == nil { return nil, fmt.Errorf("could not load transcription model") diff --git a/core/backend/tts.go b/core/backend/tts.go index e6191cfb..6157f4c1 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -23,10 +23,10 @@ func ModelTTS( ) (string, *proto.Result, error) { opts := ModelOptions(backendConfig, appConfig, model.WithDefaultBackendString(model.PiperBackend)) ttsModel, err := loader.Load(opts...) - if err != nil { return "", nil, err } + defer loader.Close() if ttsModel == nil { return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model) diff --git a/core/backend/vad.go b/core/backend/vad.go index 8d148353..741dbb19 100644 --- a/core/backend/vad.go +++ b/core/backend/vad.go @@ -19,6 +19,8 @@ func VAD(request *schema.VADRequest, if err != nil { return nil, err } + defer ml.Close() + req := proto.VADRequest{ Audio: request.Audio, } diff --git a/core/cli/soundgeneration.go b/core/cli/soundgeneration.go index a8acd6ba..3c7e9af4 100644 --- a/core/cli/soundgeneration.go +++ b/core/cli/soundgeneration.go @@ -74,7 +74,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { AssetsDestination: t.BackendAssetsPath, ExternalGRPCBackends: externalBackends, } - ml := model.NewModelLoader(opts.ModelPath) + ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend) defer func() { err := ml.StopAllGRPC() diff --git a/core/cli/transcript.go b/core/cli/transcript.go index 7f5e6a9d..67b5ed1d 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -32,7 +32,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { } cl := config.NewBackendConfigLoader(t.ModelsPath) - ml := model.NewModelLoader(opts.ModelPath) + ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend) if err := cl.LoadBackendConfigsFromPath(t.ModelsPath); err != nil { return err } diff --git a/core/cli/tts.go b/core/cli/tts.go index af51ce06..283372fe 100644 --- a/core/cli/tts.go +++ b/core/cli/tts.go @@ -41,7 +41,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error { AudioDir: outputDir, AssetsDestination: t.BackendAssetsPath, } - ml := model.NewModelLoader(opts.ModelPath) + ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend) defer func() { err := ml.StopAllGRPC() diff --git a/core/http/endpoints/localai/stores.go b/core/http/endpoints/localai/stores.go index f417c580..dd8df8b1 100644 --- a/core/http/endpoints/localai/stores.go +++ b/core/http/endpoints/localai/stores.go @@ -21,6 +21,7 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi if err != nil { return err } + defer sl.Close() vals := make([][]byte, len(input.Values)) for i, v := range input.Values { @@ -48,6 +49,7 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo if err != nil { return err } + defer sl.Close() if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil { return err @@ -69,6 +71,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi if err != nil { return err } + defer sl.Close() keys, vals, err := store.GetCols(c.Context(), sb, input.Keys) if err != nil { @@ -100,6 +103,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf if err != nil { return err } + defer sl.Close() keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk) if err != nil { diff --git a/core/http/endpoints/openai/assistant_test.go b/core/http/endpoints/openai/assistant_test.go index 6858f65d..90edb935 100644 --- a/core/http/endpoints/openai/assistant_test.go +++ b/core/http/endpoints/openai/assistant_test.go @@ -40,7 +40,7 @@ func TestAssistantEndpoints(t *testing.T) { cl := &config.BackendConfigLoader{} //configsDir := "/tmp/localai/configs" modelPath := "/tmp/localai/model" - var ml = model.NewModelLoader(modelPath) + var ml = model.NewModelLoader(modelPath, false) appConfig := &config.ApplicationConfig{ ConfigsDir: configsDir, diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 20c571fd..ebf9c1c9 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -50,11 +50,10 @@ func RegisterLocalAIRoutes(router *fiber.App, router.Post("/v1/vad", vadChain...) // Stores - sl := model.NewModelLoader("") - router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig)) - router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig)) - router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig)) - router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig)) + router.Post("/stores/set", localai.StoresSetEndpoint(ml, appConfig)) + router.Post("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig)) + router.Post("/stores/get", localai.StoresGetEndpoint(ml, appConfig)) + router.Post("/stores/find", localai.StoresFindEndpoint(ml, appConfig)) if !appConfig.DisableMetrics { router.Get("/metrics", localai.LocalAIMetricsEndpoint()) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 12a1a972..1a7fdc9c 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -509,7 +509,23 @@ func (ml *ModelLoader) stopActiveBackends(modelID string, singleActiveBackend bo } } +func (ml *ModelLoader) Close() { + if !ml.singletonMode { + return + } + ml.singletonLock.Unlock() +} + +func (ml *ModelLoader) lockBackend() { + if !ml.singletonMode { + return + } + ml.singletonLock.Lock() +} + func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { + ml.lockBackend() // grab the singleton lock if needed + o := NewOptions(opts...) // Return earlier if we have a model already loaded @@ -520,7 +536,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { return m.GRPC(o.parallelRequests, ml.wd), nil } - ml.stopActiveBackends(o.modelID, o.singleActiveBackend) + ml.stopActiveBackends(o.modelID, ml.singletonMode) // if a backend is defined, return the loader directly if o.backendString != "" { @@ -533,6 +549,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { // get backends embedded in the binary autoLoadBackends, err := ml.ListAvailableBackends(o.assetDir) if err != nil { + ml.Close() // we failed, release the lock return nil, err } @@ -564,5 +581,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { } } + ml.Close() // make sure to release the lock in case of failure + return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error()) } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index c25662d3..e74ea97b 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -18,16 +18,19 @@ import ( // TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl type ModelLoader struct { - ModelPath string - mu sync.Mutex - models map[string]*Model - wd *WatchDog + ModelPath string + mu sync.Mutex + singletonLock sync.Mutex + singletonMode bool + models map[string]*Model + wd *WatchDog } -func NewModelLoader(modelPath string) *ModelLoader { +func NewModelLoader(modelPath string, singleActiveBackend bool) *ModelLoader { nml := &ModelLoader{ - ModelPath: modelPath, - models: make(map[string]*Model), + ModelPath: modelPath, + models: make(map[string]*Model), + singletonMode: singleActiveBackend, } return nml diff --git a/pkg/model/loader_options.go b/pkg/model/loader_options.go index c151d53b..28a7c598 100644 --- a/pkg/model/loader_options.go +++ b/pkg/model/loader_options.go @@ -17,10 +17,9 @@ type Options struct { externalBackends map[string]string - grpcAttempts int - grpcAttemptsDelay int - singleActiveBackend bool - parallelRequests bool + grpcAttempts int + grpcAttemptsDelay int + parallelRequests bool } type Option func(*Options) @@ -88,12 +87,6 @@ func WithContext(ctx context.Context) Option { } } -func WithSingleActiveBackend() Option { - return func(o *Options) { - o.singleActiveBackend = true - } -} - func WithModelID(id string) Option { return func(o *Options) { o.modelID = id diff --git a/pkg/model/loader_test.go b/pkg/model/loader_test.go index 83e47ec6..a8e77bd2 100644 --- a/pkg/model/loader_test.go +++ b/pkg/model/loader_test.go @@ -21,7 +21,7 @@ var _ = Describe("ModelLoader", func() { // Setup the model loader with a test directory modelPath = "/tmp/test_model_path" os.Mkdir(modelPath, 0755) - modelLoader = model.NewModelLoader(modelPath) + modelLoader = model.NewModelLoader(modelPath, false) }) AfterEach(func() { diff --git a/tests/integration/stores_test.go b/tests/integration/stores_test.go index 9612bec0..5484a79c 100644 --- a/tests/integration/stores_test.go +++ b/tests/integration/stores_test.go @@ -70,7 +70,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" model.WithModel("test"), } - sl = model.NewModelLoader("") + sl = model.NewModelLoader("", false) sc, err = sl.Load(storeOpts...) Expect(err).ToNot(HaveOccurred()) Expect(sc).ToNot(BeNil()) @@ -235,7 +235,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" keys := [][]float32{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}, {-1.0, 0.0, 0.0}} vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-z")} - err := store.SetCols(context.Background(), sc, keys, vals); + err := store.SetCols(context.Background(), sc, keys, vals) Expect(err).ToNot(HaveOccurred()) _, _, sims, err := store.Find(context.Background(), sc, keys[0], 4) @@ -247,7 +247,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" keys := [][]float32{{1.0, 0.0, 1.0}, {0.0, 2.0, 0.0}, {0.0, 0.0, -1.0}, {-1.0, 0.0, -1.0}} vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-z")} - err := store.SetCols(context.Background(), sc, keys, vals); + err := store.SetCols(context.Background(), sc, keys, vals) Expect(err).ToNot(HaveOccurred()) _, _, sims, err := store.Find(context.Background(), sc, keys[0], 4) @@ -314,7 +314,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" normalize(keys[6:]) - err := store.SetCols(context.Background(), sc, keys, vals); + err := store.SetCols(context.Background(), sc, keys, vals) Expect(err).ToNot(HaveOccurred()) expectTriangleEq(keys, vals) @@ -341,7 +341,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" c += 1 } - err := store.SetCols(context.Background(), sc, keys, vals); + err := store.SetCols(context.Background(), sc, keys, vals) Expect(err).ToNot(HaveOccurred()) expectTriangleEq(keys, vals)