mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-23 20:14:59 +00:00
feat: allow to specify default backend for model (#156)
Signed-off-by: mudler <mudler@c3os.io>
This commit is contained in:
parent
70caf9bf8c
commit
1ae7150810
7 changed files with 97 additions and 77 deletions
|
@ -168,13 +168,6 @@ func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
|
|||
return m, nil
|
||||
}
|
||||
|
||||
// TODO: This needs refactoring, it's really bad to have it in here
|
||||
// Check if we have a GPTStable model loaded instead - if we do we return an error so the API tries with StableLM
|
||||
if _, ok := ml.gptstablelmmodels[modelName]; ok {
|
||||
log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
|
||||
return nil, fmt.Errorf("this model is a GPTStableLM one")
|
||||
}
|
||||
|
||||
// Load the model and keep it in memory for later use
|
||||
modelFile := filepath.Join(ml.ModelPath, modelName)
|
||||
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
|
||||
|
@ -207,17 +200,6 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
|
|||
return m, nil
|
||||
}
|
||||
|
||||
// TODO: This needs refactoring, it's really bad to have it in here
|
||||
// Check if we have a GPT2 model loaded instead - if we do we return an error so the API tries with GPT2
|
||||
if _, ok := ml.gpt2models[modelName]; ok {
|
||||
log.Debug().Msgf("Model is GPT2: %s", modelName)
|
||||
return nil, fmt.Errorf("this model is a GPT2 one")
|
||||
}
|
||||
if _, ok := ml.gptstablelmmodels[modelName]; ok {
|
||||
log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
|
||||
return nil, fmt.Errorf("this model is a GPTStableLM one")
|
||||
}
|
||||
|
||||
// Load the model and keep it in memory for later use
|
||||
modelFile := filepath.Join(ml.ModelPath, modelName)
|
||||
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
|
||||
|
@ -252,21 +234,6 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
|
|||
return m, nil
|
||||
}
|
||||
|
||||
// TODO: This needs refactoring, it's really bad to have it in here
|
||||
// Check if we have a GPTJ model loaded instead - if we do we return an error so the API tries with GPTJ
|
||||
if _, ok := ml.gptmodels[modelName]; ok {
|
||||
log.Debug().Msgf("Model is GPTJ: %s", modelName)
|
||||
return nil, fmt.Errorf("this model is a GPTJ one")
|
||||
}
|
||||
if _, ok := ml.gpt2models[modelName]; ok {
|
||||
log.Debug().Msgf("Model is GPT2: %s", modelName)
|
||||
return nil, fmt.Errorf("this model is a GPT2 one")
|
||||
}
|
||||
if _, ok := ml.gptstablelmmodels[modelName]; ok {
|
||||
log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
|
||||
return nil, fmt.Errorf("this model is a GPTStableLM one")
|
||||
}
|
||||
|
||||
// Load the model and keep it in memory for later use
|
||||
modelFile := filepath.Join(ml.ModelPath, modelName)
|
||||
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue