Enhancements (#34)

Signed-off-by: mudler <mudler@c3os.io>
This commit is contained in:
Ettore Di Giacinto 2023-04-19 17:10:29 +02:00 committed by GitHub
parent a9a875ee2b
commit 7fec26f5d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 226 additions and 78 deletions

View file

@ -7,6 +7,7 @@ import (
model "github.com/go-skynet/llama-cli/pkg/model"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
@ -60,13 +61,16 @@ type OpenAIRequest struct {
Batch int `json:"batch"`
F16 bool `json:"f16kv"`
IgnoreEOS bool `json:"ignore_eos"`
Seed int `json:"seed"`
}
// https://platform.openai.com/docs/api-reference/completions
func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
var err error
var model *llama.LLama
var gptModel *gptj.GPTJ
input := new(OpenAIRequest)
// Get input data from the request body
@ -77,9 +81,22 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
if input.Model == "" {
return fmt.Errorf("no model specified")
} else {
model, err = loader.LoadModel(input.Model)
if err != nil {
return err
// Try to load the model with both
var llamaerr error
llamaOpts := []llama.ModelOption{}
if ctx != 0 {
llamaOpts = append(llamaOpts, llama.SetContext(ctx))
}
if f16 {
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
}
model, llamaerr = loader.LoadLLaMAModel(input.Model, llamaOpts...)
if llamaerr != nil {
gptModel, err = loader.LoadGPTJModel(input.Model)
if err != nil {
return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
}
}
}
@ -146,32 +163,70 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
n = 1
}
var predFunc func() (string, error)
switch {
case gptModel != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gptj.PredictOption{
gptj.SetTemperature(temperature),
gptj.SetTopP(topP),
gptj.SetTopK(topK),
gptj.SetTokens(tokens),
gptj.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, gptj.SetBatch(input.Batch))
}
if input.Seed != 0 {
predictOptions = append(predictOptions, gptj.SetSeed(input.Seed))
}
return gptModel.Predict(
predInput,
predictOptions...,
)
}
case model != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []llama.PredictOption{
llama.SetTemperature(temperature),
llama.SetTopP(topP),
llama.SetTopK(topK),
llama.SetTokens(tokens),
llama.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
}
if input.F16 {
predictOptions = append(predictOptions, llama.EnableF16KV)
}
if input.IgnoreEOS {
predictOptions = append(predictOptions, llama.IgnoreEOS)
}
if input.Seed != 0 {
predictOptions = append(predictOptions, llama.SetSeed(input.Seed))
}
return model.Predict(
predInput,
predictOptions...,
)
}
}
for i := 0; i < n; i++ {
// Generate the prediction using the language model
predictOptions := []llama.PredictOption{
llama.SetTemperature(temperature),
llama.SetTopP(topP),
llama.SetTopK(topK),
llama.SetTokens(tokens),
llama.SetThreads(threads),
}
var prediction string
if input.Batch != 0 {
predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
}
if input.F16 {
predictOptions = append(predictOptions, llama.EnableF16KV)
}
if input.IgnoreEOS {
predictOptions = append(predictOptions, llama.IgnoreEOS)
}
prediction, err := model.Predict(
predInput,
predictOptions...,
)
prediction, err := predFunc()
if err != nil {
return err
}
@ -179,6 +234,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
if input.Echo {
prediction = predInput + prediction
}
if chat {
result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}})
} else {
@ -194,7 +250,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
}
}
func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error {
app := fiber.New()
// Default middleware config
@ -207,8 +263,8 @@ func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
var mumutex = &sync.Mutex{}
// openAI compatible API endpoint
app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, mutex, mumutex, mu))
app.Post("/v1/completions", openAIEndpoint(false, loader, threads, mutex, mumutex, mu))
app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mutex, mumutex, mu))
app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mutex, mumutex, mu))
app.Get("/v1/models", func(c *fiber.Ctx) error {
models, err := loader.ListModels()
if err != nil {