mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-25 13:04:59 +00:00
Add support for stablelm (#48)
Signed-off-by: mudler <mudler@mocaccino.org>
This commit is contained in:
parent
142bcd66ca
commit
f816dfae65
7 changed files with 106 additions and 13 deletions
32
api/api.go
32
api/api.go
|
@ -75,6 +75,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
|
|||
var model *llama.LLama
|
||||
var gptModel *gptj.GPTJ
|
||||
var gpt2Model *gpt2.GPT2
|
||||
var stableLMModel *gpt2.StableLM
|
||||
|
||||
input := new(OpenAIRequest)
|
||||
// Get input data from the request body
|
||||
|
@ -99,7 +100,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
|
|||
}
|
||||
|
||||
// Try to load the model with both
|
||||
var llamaerr, gpt2err, gptjerr error
|
||||
var llamaerr, gpt2err, gptjerr, stableerr error
|
||||
llamaOpts := []llama.ModelOption{}
|
||||
if ctx != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetContext(ctx))
|
||||
|
@ -115,7 +116,10 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
|
|||
if gptjerr != nil {
|
||||
gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
|
||||
if gpt2err != nil {
|
||||
return fmt.Errorf("llama: %s gpt: %s gpt2: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error()) // llama failed first, so we want to catch both errors
|
||||
stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
|
||||
if stableerr != nil {
|
||||
return fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to catch both errors
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -182,6 +186,30 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
|
|||
|
||||
var predFunc func() (string, error)
|
||||
switch {
|
||||
case stableLMModel != nil:
|
||||
predFunc = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []gpt2.PredictOption{
|
||||
gpt2.SetTemperature(temperature),
|
||||
gpt2.SetTopP(topP),
|
||||
gpt2.SetTopK(topK),
|
||||
gpt2.SetTokens(tokens),
|
||||
gpt2.SetThreads(threads),
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
|
||||
}
|
||||
|
||||
return stableLMModel.Predict(
|
||||
predInput,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case gpt2Model != nil:
|
||||
predFunc = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue