mirror of https://github.com/mudler/LocalAI.git
synced 2025-05-19 18:15:00 +00:00
feat(loader): enhance single active backend by treating as singleton (#5107)

feat(loader): enhance single active backend by treating as singleton

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

This commit is contained in:
parent c59975ab05
commit 2c425e9c69

24 changed files with 92 additions and 71 deletions
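What the commit does, in one sentence: with the SingleBackend setting enabled, the model loader now behaves as a true singleton — Load takes a process-wide lock before starting a backend, and the new Close releases it, so at most one backend is ever alive. A minimal standalone sketch of that idea (stand-in types, not the repo's actual loader):

package main

import (
	"fmt"
	"sync"
)

// Loader is a stand-in for pkg/model's ModelLoader.
type Loader struct {
	singletonMode bool
	singletonLock sync.Mutex
}

// Load blocks until any previously loaded backend is released via Close.
func (l *Loader) Load(name string) string {
	if l.singletonMode {
		l.singletonLock.Lock()
	}
	return name // stand-in for spawning a gRPC backend process
}

// Close frees the singleton slot so the next Load can proceed.
func (l *Loader) Close() {
	if l.singletonMode {
		l.singletonLock.Unlock()
	}
}

func main() {
	l := &Loader{singletonMode: true}
	m := l.Load("model-a")
	fmt.Println("active:", m)
	l.Close() // without this, the next Load would block forever
}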
@@ -16,7 +16,7 @@ type Application struct {
 func newApplication(appConfig *config.ApplicationConfig) *Application {
 	return &Application{
 		backendLoader:      config.NewBackendConfigLoader(appConfig.ModelPath),
-		modelLoader:        model.NewModelLoader(appConfig.ModelPath),
+		modelLoader:        model.NewModelLoader(appConfig.ModelPath, appConfig.SingleBackend),
 		applicationConfig:  appConfig,
 		templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
 	}
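This wiring is the heart of the change: the singleton flag now travels with the loader from the moment it is constructed. A hedged sketch of the new call shape (newLoader is a hypothetical helper; the package paths are the repo's real ones):

package application

import (
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/model"
)

// newLoader (hypothetical) highlights the new two-argument constructor:
// the second argument switches the loader into singleton mode for its lifetime.
func newLoader(appConfig *config.ApplicationConfig) *model.ModelLoader {
	return model.NewModelLoader(appConfig.ModelPath, appConfig.SingleBackend)
}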
@@ -143,7 +143,7 @@ func New(opts ...config.AppOption) (*Application, error) {
 		}()
 	}
 
-	if options.LoadToMemory != nil {
+	if options.LoadToMemory != nil && !options.SingleBackend {
 		for _, m := range options.LoadToMemory {
 			cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options)
 			if err != nil {
@@ -17,6 +17,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
 	if err != nil {
 		return nil, err
 	}
+	defer loader.Close()
 
 	var fn func() ([]float32, error)
 	switch model := inferenceModel.(type) {
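The `defer loader.Close()` added here is the pattern repeated across the following hunks (image generation, LLM inference, rerank, tokenize, TTS, transcription, VAD, stores): once Load succeeds, the caller owns the singleton slot and releases it when the request finishes. A condensed sketch of that call-site contract (withBackend is a hypothetical helper around the real Load/Close API):

func withBackend(loader *model.ModelLoader, opts ...model.Option) error {
	backend, err := loader.Load(opts...)
	if err != nil {
		// Failure paths inside Load release the lock themselves
		// (see the loader hunks near the end of this diff).
		return err
	}
	defer loader.Close() // success: free the singleton slot on return

	_ = backend // ... run the actual inference here ...
	return nil
}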
@@ -16,6 +16,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
 	if err != nil {
 		return nil, err
 	}
+	defer loader.Close()
 
 	fn := func() error {
 		_, err := inferenceModel.GenerateImage(
@@ -53,6 +53,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
 	if err != nil {
 		return nil, err
 	}
+	defer loader.Close()
 
 	var protoMessages []*proto.Message
 	// if we are using the tokenizer template, we need to convert the messages to proto messages
@@ -40,10 +40,6 @@ func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...
 	grpcOpts := grpcModelOpts(c)
 	defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
 
-	if so.SingleBackend {
-		defOpts = append(defOpts, model.WithSingleActiveBackend())
-	}
-
 	if so.ParallelBackendRequests {
 		defOpts = append(defOpts, model.EnableParallelRequests)
 	}
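With WithSingleActiveBackend() gone, ModelOptions no longer consults the setting on every request; it is decided once, at construction. Roughly, the migration for any embedding caller looks like this (hedged sketch, hypothetical variable names):

// Before this commit: per-Load option.
//	ml := model.NewModelLoader(modelPath)
//	backend, err := ml.Load(model.WithSingleActiveBackend(), model.WithModel(name))
//
// After: constructor-time flag; Load needs no extra option.
//	ml := model.NewModelLoader(modelPath, true)
//	backend, err := ml.Load(model.WithModel(name))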
@@ -121,7 +117,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
 	triggers := make([]*pb.GrammarTrigger, 0)
 	for _, t := range c.FunctionsConfig.GrammarConfig.GrammarTriggers {
 		triggers = append(triggers, &pb.GrammarTrigger{
 			Word: t.Word,
 		})
 
 	}
@@ -161,33 +157,33 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
 		DisableLogStatus: c.DisableLogStatus,
 		DType:            c.DType,
 		// LimitMMPerPrompt vLLM
 		LimitImagePerPrompt: int32(c.LimitMMPerPrompt.LimitImagePerPrompt),
 		LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt),
 		LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt),
 		MMProj:           c.MMProj,
 		FlashAttention:   c.FlashAttention,
 		CacheTypeKey:     c.CacheTypeK,
 		CacheTypeValue:   c.CacheTypeV,
 		NoKVOffload:      c.NoKVOffloading,
 		YarnExtFactor:    c.YarnExtFactor,
 		YarnAttnFactor:   c.YarnAttnFactor,
 		YarnBetaFast:     c.YarnBetaFast,
 		YarnBetaSlow:     c.YarnBetaSlow,
 		NGQA:             c.NGQA,
 		RMSNormEps:       c.RMSNormEps,
 		MLock:            mmlock,
 		RopeFreqBase:     c.RopeFreqBase,
 		RopeScaling:      c.RopeScaling,
 		Type:             c.ModelType,
 		RopeFreqScale:    c.RopeFreqScale,
 		NUMA:             c.NUMA,
 		Embeddings:       embeddings,
 		LowVRAM:          lowVRAM,
 		NGPULayers:       int32(nGPULayers),
 		MMap:             mmap,
 		MainGPU:          c.MainGPU,
 		Threads:          int32(*c.Threads),
 		TensorSplit:      c.TensorSplit,
 		// AutoGPTQ
 		ModelBaseName:    c.AutoGPTQ.ModelBaseName,
 		Device:           c.AutoGPTQ.Device,
@@ -12,10 +12,10 @@ import (
 func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
 	opts := ModelOptions(backendConfig, appConfig)
 	rerankModel, err := loader.Load(opts...)
-
 	if err != nil {
 		return nil, err
 	}
+	defer loader.Close()
 
 	if rerankModel == nil {
 		return nil, fmt.Errorf("could not load rerank model")
@@ -26,10 +26,10 @@ func SoundGeneration(
 
 	opts := ModelOptions(backendConfig, appConfig)
 	soundGenModel, err := loader.Load(opts...)
-
 	if err != nil {
 		return "", nil, err
 	}
+	defer loader.Close()
 
 	if soundGenModel == nil {
 		return "", nil, fmt.Errorf("could not load sound generation model")
@@ -20,6 +20,7 @@ func TokenMetrics(
 	if err != nil {
 		return nil, err
 	}
+	defer loader.Close()
 
 	if model == nil {
 		return nil, fmt.Errorf("could not loadmodel model")
@@ -14,10 +14,10 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac
 
 	opts := ModelOptions(backendConfig, appConfig)
 	inferenceModel, err = loader.Load(opts...)
-
 	if err != nil {
 		return schema.TokenizeResponse{}, err
 	}
+	defer loader.Close()
 
 	predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
 	predictOptions.Prompt = s
@@ -24,6 +24,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
 	if err != nil {
 		return nil, err
 	}
+	defer ml.Close()
 
 	if transcriptionModel == nil {
 		return nil, fmt.Errorf("could not load transcription model")
@@ -23,10 +23,10 @@ func ModelTTS(
 ) (string, *proto.Result, error) {
 	opts := ModelOptions(backendConfig, appConfig, model.WithDefaultBackendString(model.PiperBackend))
 	ttsModel, err := loader.Load(opts...)
-
 	if err != nil {
 		return "", nil, err
 	}
+	defer loader.Close()
 
 	if ttsModel == nil {
 		return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
@@ -19,6 +19,8 @@ func VAD(request *schema.VADRequest,
 	if err != nil {
 		return nil, err
 	}
+	defer ml.Close()
+
 	req := proto.VADRequest{
 		Audio: request.Audio,
 	}
@@ -74,7 +74,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
 		AssetsDestination:    t.BackendAssetsPath,
 		ExternalGRPCBackends: externalBackends,
 	}
-	ml := model.NewModelLoader(opts.ModelPath)
+	ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
 
 	defer func() {
 		err := ml.StopAllGRPC()
@@ -32,7 +32,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
 	}
 
 	cl := config.NewBackendConfigLoader(t.ModelsPath)
-	ml := model.NewModelLoader(opts.ModelPath)
+	ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
 	if err := cl.LoadBackendConfigsFromPath(t.ModelsPath); err != nil {
 		return err
 	}
@@ -41,7 +41,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
 		AudioDir:          outputDir,
 		AssetsDestination: t.BackendAssetsPath,
 	}
-	ml := model.NewModelLoader(opts.ModelPath)
+	ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
 
 	defer func() {
 		err := ml.StopAllGRPC()
@@ -21,6 +21,7 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
 	if err != nil {
 		return err
 	}
+	defer sl.Close()
 
 	vals := make([][]byte, len(input.Values))
 	for i, v := range input.Values {
@@ -48,6 +49,7 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
 	if err != nil {
 		return err
 	}
+	defer sl.Close()
 
 	if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
 		return err
@@ -69,6 +71,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
 	if err != nil {
 		return err
 	}
+	defer sl.Close()
 
 	keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
 	if err != nil {
@@ -100,6 +103,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
 	if err != nil {
 		return err
 	}
+	defer sl.Close()
 
 	keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
 	if err != nil {
@@ -40,7 +40,7 @@ func TestAssistantEndpoints(t *testing.T) {
 	cl := &config.BackendConfigLoader{}
 	//configsDir := "/tmp/localai/configs"
 	modelPath := "/tmp/localai/model"
-	var ml = model.NewModelLoader(modelPath)
+	var ml = model.NewModelLoader(modelPath, false)
 
 	appConfig := &config.ApplicationConfig{
 		ConfigsDir: configsDir,
@@ -50,11 +50,10 @@ func RegisterLocalAIRoutes(router *fiber.App,
 	router.Post("/v1/vad", vadChain...)
 
 	// Stores
-	sl := model.NewModelLoader("")
-	router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
-	router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
-	router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
-	router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
+	router.Post("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
+	router.Post("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
+	router.Post("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
+	router.Post("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
 
 	if !appConfig.DisableMetrics {
 		router.Get("/metrics", localai.LocalAIMetricsEndpoint())
@@ -509,7 +509,23 @@ func (ml *ModelLoader) stopActiveBackends(modelID string, singleActiveBackend bo
 	}
 }
 
+func (ml *ModelLoader) Close() {
+	if !ml.singletonMode {
+		return
+	}
+	ml.singletonLock.Unlock()
+}
+
+func (ml *ModelLoader) lockBackend() {
+	if !ml.singletonMode {
+		return
+	}
+	ml.singletonLock.Lock()
+}
+
 func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
+	ml.lockBackend() // grab the singleton lock if needed
+
 	o := NewOptions(opts...)
 
 	// Return earlier if we have a model already loaded
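Both helpers are deliberate no-ops unless singleton mode is on, so the default path costs one branch. When it is on, concurrent Loads serialize against each other. A standalone demonstration of the same gating (stand-in types, not the repo's loader):

package main

import (
	"fmt"
	"sync"
	"time"
)

// gate mirrors lockBackend/Close: a mutex that only bites in singleton mode.
type gate struct {
	enabled bool
	mu      sync.Mutex
}

func (g *gate) lock() {
	if g.enabled {
		g.mu.Lock()
	}
}

func (g *gate) unlock() {
	if g.enabled {
		g.mu.Unlock()
	}
}

func main() {
	g := &gate{enabled: true}
	var wg sync.WaitGroup
	for _, name := range []string{"model-a", "model-b"} {
		wg.Add(1)
		go func(n string) {
			defer wg.Done()
			g.lock()         // what Load now does on entry
			defer g.unlock() // what the caller's `defer loader.Close()` does
			fmt.Println("only active backend:", n)
			time.Sleep(50 * time.Millisecond) // pretend to serve a request
		}(name)
	}
	wg.Wait() // the two fake backends were never active at the same time
}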
@@ -520,7 +536,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
 		return m.GRPC(o.parallelRequests, ml.wd), nil
 	}
 
-	ml.stopActiveBackends(o.modelID, o.singleActiveBackend)
+	ml.stopActiveBackends(o.modelID, ml.singletonMode)
 
 	// if a backend is defined, return the loader directly
 	if o.backendString != "" {
@@ -533,6 +549,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
 	// get backends embedded in the binary
 	autoLoadBackends, err := ml.ListAvailableBackends(o.assetDir)
 	if err != nil {
+		ml.Close() // we failed, release the lock
 		return nil, err
 	}
 
@@ -564,5 +581,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
 		}
 	}
 
+	ml.Close() // make sure to release the lock in case of failure
+
 	return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
 }
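Taken together, the three Load hunks establish a simple lock discipline, restated here as a hedged comment sketch (condensed, not the literal function body):

// func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
//	ml.lockBackend()    // acquire up front (no-op unless singleton mode)
//	...
//	if err != nil {
//		ml.Close()  // every error return frees the slot itself
//		return nil, err
//	}
//	return backend, nil // success: the caller frees it via defer Close()
// }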
@@ -18,16 +18,19 @@ import (
 
 // TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl
 type ModelLoader struct {
 	ModelPath string
 	mu        sync.Mutex
-	models    map[string]*Model
-	wd        *WatchDog
+	singletonLock sync.Mutex
+	singletonMode bool
+	models        map[string]*Model
+	wd            *WatchDog
 }
 
-func NewModelLoader(modelPath string) *ModelLoader {
+func NewModelLoader(modelPath string, singleActiveBackend bool) *ModelLoader {
 	nml := &ModelLoader{
 		ModelPath: modelPath,
 		models:    make(map[string]*Model),
+		singletonMode: singleActiveBackend,
 	}
 
 	return nml
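One reading of why the struct now carries two mutexes: mu stays a short-lived guard for the models map, while singletonLock can be held across a backend's whole active lifetime. An annotated restatement of the fields above (the comments are editorial, not the source's):

type ModelLoader struct {
	ModelPath string
	mu        sync.Mutex // short-lived: guards the models map during load/stop
	// Held from Load until the caller's Close in singleton mode — i.e. across
	// an entire request — which is why it cannot be folded into mu above:
	singletonLock sync.Mutex
	singletonMode bool
	models        map[string]*Model
	wd            *WatchDog
}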
@@ -17,10 +17,9 @@ type Options struct {
 
 	externalBackends map[string]string
 
 	grpcAttempts      int
 	grpcAttemptsDelay int
-	singleActiveBackend bool
-	parallelRequests    bool
+	parallelRequests  bool
 }
 
 type Option func(*Options)
@@ -88,12 +87,6 @@ func WithContext(ctx context.Context) Option {
 	}
 }
 
-func WithSingleActiveBackend() Option {
-	return func(o *Options) {
-		o.singleActiveBackend = true
-	}
-}
-
 func WithModelID(id string) Option {
 	return func(o *Options) {
 		o.modelID = id
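Embedders and tests that do not want serialization simply pass false, as every updated call site below does. For example (a hedged sketch in the repo's Ginkgo style; this spec itself is hypothetical):

var _ = Describe("ModelLoader in default mode", func() {
	It("constructs without singleton gating", func() {
		ml := model.NewModelLoader("/tmp/test_model_path", false)
		Expect(ml).ToNot(BeNil())
	})
})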
@@ -21,7 +21,7 @@ var _ = Describe("ModelLoader", func() {
 		// Setup the model loader with a test directory
 		modelPath = "/tmp/test_model_path"
 		os.Mkdir(modelPath, 0755)
-		modelLoader = model.NewModelLoader(modelPath)
+		modelLoader = model.NewModelLoader(modelPath, false)
 	})
 
 	AfterEach(func() {
@@ -70,7 +70,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
 			model.WithModel("test"),
 		}
 
-		sl = model.NewModelLoader("")
+		sl = model.NewModelLoader("", false)
 		sc, err = sl.Load(storeOpts...)
 		Expect(err).ToNot(HaveOccurred())
 		Expect(sc).ToNot(BeNil())
@@ -235,7 +235,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
 			keys := [][]float32{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}, {-1.0, 0.0, 0.0}}
 			vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-z")}
 
-			err := store.SetCols(context.Background(), sc, keys, vals);
+			err := store.SetCols(context.Background(), sc, keys, vals)
 			Expect(err).ToNot(HaveOccurred())
 
 			_, _, sims, err := store.Find(context.Background(), sc, keys[0], 4)
@@ -247,7 +247,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
 			keys := [][]float32{{1.0, 0.0, 1.0}, {0.0, 2.0, 0.0}, {0.0, 0.0, -1.0}, {-1.0, 0.0, -1.0}}
 			vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-z")}
 
-			err := store.SetCols(context.Background(), sc, keys, vals);
+			err := store.SetCols(context.Background(), sc, keys, vals)
 			Expect(err).ToNot(HaveOccurred())
 
 			_, _, sims, err := store.Find(context.Background(), sc, keys[0], 4)
@@ -314,7 +314,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
 
 			normalize(keys[6:])
 
-			err := store.SetCols(context.Background(), sc, keys, vals);
+			err := store.SetCols(context.Background(), sc, keys, vals)
 			Expect(err).ToNot(HaveOccurred())
 
 			expectTriangleEq(keys, vals)
|
||||||
c += 1
|
c += 1
|
||||||
}
|
}
|
||||||
|
|
||||||
err := store.SetCols(context.Background(), sc, keys, vals);
|
err := store.SetCols(context.Background(), sc, keys, vals)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
expectTriangleEq(keys, vals)
|
expectTriangleEq(keys, vals)
|
||||||
|
|