Add support for cerebras (#45)

Signed-off-by: mudler <mudler@c3os.io>
This commit is contained in:
Ettore Di Giacinto 2023-04-20 19:33:36 +02:00 committed by GitHub
parent d517a54e28
commit 1c4fbaae20
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 124 additions and 20 deletions

View file

@ -8,6 +8,7 @@ import (
"sync"
model "github.com/go-skynet/LocalAI/pkg/model"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/gofiber/fiber/v2"
@ -73,6 +74,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
var err error
var model *llama.LLama
var gptModel *gptj.GPTJ
var gpt2Model *gpt2.GPT2
input := new(OpenAIRequest)
// Get input data from the request body
@ -97,7 +99,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
}
// Try to load the model with both
var llamaerr error
var llamaerr, gpt2err, gptjerr error
llamaOpts := []llama.ModelOption{}
if ctx != 0 {
llamaOpts = append(llamaOpts, llama.SetContext(ctx))
@ -106,11 +108,15 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
}
// TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
if llamaerr != nil {
gptModel, err = loader.LoadGPTJModel(modelFile)
if err != nil {
return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
if gptjerr != nil {
gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
if gpt2err != nil {
return fmt.Errorf("llama: %s gpt: %s gpt2: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error()) // llama failed first, so we want to catch both errors
}
}
}
@ -176,6 +182,30 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
var predFunc func() (string, error)
switch {
case gpt2Model != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(temperature),
gpt2.SetTopP(topP),
gpt2.SetTopK(topK),
gpt2.SetTokens(tokens),
gpt2.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
}
if input.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
}
return gpt2Model.Predict(
predInput,
predictOptions...,
)
}
case gptModel != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model