Add support for stablelm (#48)

Signed-off-by: mudler <mudler@mocaccino.org>
Ettore Di Giacinto 2023-04-21 00:06:55 +02:00 committed by GitHub
parent 142bcd66ca
commit f816dfae65
7 changed files with 106 additions and 13 deletions

@@ -75,6 +75,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		var model *llama.LLama
 		var gptModel *gptj.GPTJ
 		var gpt2Model *gpt2.GPT2
+		var stableLMModel *gpt2.StableLM
 
 		input := new(OpenAIRequest)
 		// Get input data from the request body
@@ -99,7 +100,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		}
 
 		// Try to load the model with both
-		var llamaerr, gpt2err, gptjerr error
+		var llamaerr, gpt2err, gptjerr, stableerr error
 		llamaOpts := []llama.ModelOption{}
 		if ctx != 0 {
 			llamaOpts = append(llamaOpts, llama.SetContext(ctx))
@@ -115,7 +116,10 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 			if gptjerr != nil {
 				gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
 				if gpt2err != nil {
-					return fmt.Errorf("llama: %s gpt: %s gpt2: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error()) // llama failed first, so we want to catch both errors
+					stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
+					if stableerr != nil {
+						return fmt.Errorf("llama: %s gptj: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so return every backend's error
+					}
 				}
 			}
 		}
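
Aside: the nested-if chain above grows one indentation level per backend. As a point of comparison, here is a minimal, hypothetical sketch of the same first-success-wins, aggregate-all-errors pattern written as an ordered loop; tryBackends, backend, and loadFn are illustrative names and are not part of this commit.

package api

import (
	"fmt"
	"strings"
)

// loadFn mirrors the shape of the loader.Load*Model methods used above.
type loadFn func(modelFile string) (interface{}, error)

// backend pairs a human-readable name with its loader, so failures can be
// labelled the way the fmt.Errorf call above labels them.
type backend struct {
	name string
	load loadFn
}

// tryBackends attempts each backend in order and returns the first model
// that loads; only if every backend fails does it surface the combined
// errors, matching the behaviour of the nested ifs in the hunk above.
func tryBackends(modelFile string, backends []backend) (interface{}, error) {
	var errs []string
	for _, b := range backends {
		m, err := b.load(modelFile)
		if err == nil {
			return m, nil
		}
		errs = append(errs, fmt.Sprintf("%s: %s", b.name, err))
	}
	return nil, fmt.Errorf("could not load %q with any backend: %s", modelFile, strings.Join(errs, " "))
}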
@@ -182,6 +186,30 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 
 		var predFunc func() (string, error)
 		switch {
+		case stableLMModel != nil:
+			predFunc = func() (string, error) {
+				// Generate the prediction using the language model
+				predictOptions := []gpt2.PredictOption{
+					gpt2.SetTemperature(temperature),
+					gpt2.SetTopP(topP),
+					gpt2.SetTopK(topK),
+					gpt2.SetTokens(tokens),
+					gpt2.SetThreads(threads),
+				}
+
+				if input.Batch != 0 {
+					predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
+				}
+
+				if input.Seed != 0 {
+					predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
+				}
+
+				return stableLMModel.Predict(
+					predInput,
+					predictOptions...,
+				)
+			}
 		case gpt2Model != nil:
 			predFunc = func() (string, error) {
 				// Generate the prediction using the language model
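
The new stableLM case reuses the go-gpt2 functional options already used by the gpt2 case: the always-set options (SetTemperature, SetTopP, SetTopK, SetTokens, SetThreads) go into the initial slice, while SetBatch and SetSeed are appended only when the request carries a non-zero value. A hypothetical helper that assembles that slice could look like the sketch below; buildStableLMOptions, the parameter types, and the import path are assumptions, and only the gpt2.Set* setters are taken from the diff.

package api

// The import path is an assumption based on the binding this commit targets.
import (
	gpt2 "github.com/go-skynet/go-gpt2"
)

// buildStableLMOptions mirrors how the handler above assembles the option
// slice from request fields before calling stableLMModel.Predict.
func buildStableLMOptions(temperature, topP float64, topK, tokens, threads, batch, seed int) []gpt2.PredictOption {
	opts := []gpt2.PredictOption{
		gpt2.SetTemperature(temperature),
		gpt2.SetTopP(topP),
		gpt2.SetTopK(topK),
		gpt2.SetTokens(tokens),
		gpt2.SetThreads(threads),
	}
	// Zero means "not provided" in the request, so only forward explicit values.
	if batch != 0 {
		opts = append(opts, gpt2.SetBatch(batch))
	}
	if seed != 0 {
		opts = append(opts, gpt2.SetSeed(seed))
	}
	return opts
}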