mirror of https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
parent d517a54e28
commit 1c4fbaae20
6 changed files with 124 additions and 20 deletions
api/api.go (38 changed lines)
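This change teaches the OpenAI-compatible endpoint a third backend: alongside llama (go-llama.cpp) and gptj (go-gpt4all-j.cpp), it now falls back to gpt2 (go-gpt2.cpp) when a model fails to load, and adds a matching gpt2 prediction path.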
@@ -8,6 +8,7 @@ import (
 	"sync"

 	model "github.com/go-skynet/LocalAI/pkg/model"
+	gpt2 "github.com/go-skynet/go-gpt2.cpp"
 	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
 	llama "github.com/go-skynet/go-llama.cpp"
 	"github.com/gofiber/fiber/v2"
@@ -73,6 +74,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 	var err error
 	var model *llama.LLama
 	var gptModel *gptj.GPTJ
+	var gpt2Model *gpt2.GPT2

 	input := new(OpenAIRequest)
 	// Get input data from the request body
@@ -97,7 +99,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 	}

 	// Try to load the model with both
-	var llamaerr error
+	var llamaerr, gpt2err, gptjerr error
 	llamaOpts := []llama.ModelOption{}
 	if ctx != 0 {
 		llamaOpts = append(llamaOpts, llama.SetContext(ctx))
@@ -106,11 +108,15 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		llamaOpts = append(llamaOpts, llama.EnableF16Memory)
 	}

 	// TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
 	model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
 	if llamaerr != nil {
-		gptModel, err = loader.LoadGPTJModel(modelFile)
-		if err != nil {
-			return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
+		gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
+		if gptjerr != nil {
+			gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
+			if gpt2err != nil {
+				return fmt.Errorf("llama: %s gpt: %s gpt2: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error()) // llama failed first, so we want to catch both errors
+			}
 		}
 	}
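The TODO above concedes the nested fallback is ugly. A possible refactor, shown only as a sketch (the backend struct and tryLoaders helper are hypothetical and not part of this commit; only the three Load*Model calls come from the diff), is a table-driven loop that tries each backend in order and aggregates every error:

    // Hypothetical sketch, assuming "fmt" and "strings" are imported.
    type backend struct {
        name string
        load func() (interface{}, error) // yields *llama.LLama, *gptj.GPTJ, or *gpt2.GPT2
    }

    // tryLoaders returns the first model that loads, or every error joined.
    func tryLoaders(backends []backend) (interface{}, error) {
        var errs []string
        for _, b := range backends {
            m, err := b.load()
            if err == nil {
                return m, nil
            }
            errs = append(errs, fmt.Sprintf("%s: %s", b.name, err))
        }
        return nil, fmt.Errorf("could not load model: %s", strings.Join(errs, "; "))
    }

    // Call site, preserving the llama -> gptj -> gpt2 order of the diff:
    m, err := tryLoaders([]backend{
        {"llama", func() (interface{}, error) { return loader.LoadLLaMAModel(modelFile, llamaOpts...) }},
        {"gptj", func() (interface{}, error) { return loader.LoadGPTJModel(modelFile) }},
        {"gpt2", func() (interface{}, error) { return loader.LoadGPT2Model(modelFile) }},
    })

A type switch on m would then recover the concrete model pointer for the predFunc dispatch below.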
@@ -176,6 +182,30 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16

 	var predFunc func() (string, error)
 	switch {
+	case gpt2Model != nil:
+		predFunc = func() (string, error) {
+			// Generate the prediction using the language model
+			predictOptions := []gpt2.PredictOption{
+				gpt2.SetTemperature(temperature),
+				gpt2.SetTopP(topP),
+				gpt2.SetTopK(topK),
+				gpt2.SetTokens(tokens),
+				gpt2.SetThreads(threads),
+			}
+
+			if input.Batch != 0 {
+				predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
+			}
+
+			if input.Seed != 0 {
+				predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
+			}
+
+			return gpt2Model.Predict(
+				predInput,
+				predictOptions...,
+			)
+		}
 	case gptModel != nil:
 		predFunc = func() (string, error) {
 			// Generate the prediction using the language model
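Each case of the switch rebuilds an options slice from the same request fields before closing over its model in predFunc. If that duplication grew, the gpt2 option building could move into a helper; a minimal sketch, assuming parameter types the diff does not show (buildGPT2Options is hypothetical; only the gpt2.Set* calls are taken from the hunk above):

    // Hypothetical helper, not in the commit; parameter types are assumed.
    func buildGPT2Options(temperature, topP float64, topK, tokens, threads, batch, seed int) []gpt2.PredictOption {
        opts := []gpt2.PredictOption{
            gpt2.SetTemperature(temperature),
            gpt2.SetTopP(topP),
            gpt2.SetTopK(topK),
            gpt2.SetTokens(tokens),
            gpt2.SetThreads(threads),
        }
        if batch != 0 { // zero means "not requested", matching the hunk above
            opts = append(opts, gpt2.SetBatch(batch))
        }
        if seed != 0 {
            opts = append(opts, gpt2.SetSeed(seed))
        }
        return opts
    }

The gpt2 case would then shrink to a one-line predFunc that calls gpt2Model.Predict(predInput, buildGPT2Options(...)...).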
(The diffs for the other 5 changed files are not shown.)