Add stopwords, debug mode, and other API enhancements (#54)

Signed-off-by: mudler <mudler@mocaccino.org>
Author: Ettore Di Giacinto
Date:   2023-04-21 19:46:59 +02:00 (committed by GitHub)
parent 4b7e83056d
commit 5cba71de70
5 changed files with 36 additions and 13 deletions


@@ -48,6 +48,8 @@ type OpenAIRequest struct {
 	// Prompt is read only by completion API calls
 	Prompt string `json:"prompt"`
 
+	Stop string `json:"stop"`
+
 	// Messages is read only by chat/completion API calls
 	Messages []Message `json:"messages"`
@@ -61,15 +63,17 @@ type OpenAIRequest struct {
 	N int `json:"n"`
 
 	// Custom parameters - not present in the OpenAI API
-	Batch     int  `json:"batch"`
-	F16       bool `json:"f16kv"`
-	IgnoreEOS bool `json:"ignore_eos"`
+	Batch         int     `json:"batch"`
+	F16           bool    `json:"f16kv"`
+	IgnoreEOS     bool    `json:"ignore_eos"`
+	RepeatPenalty float64 `json:"repeat_penalty"`
+	Keep          int     `json:"n_keep"`
+	Seed          int     `json:"seed"`
 }
 
 // https://platform.openai.com/docs/api-reference/completions
-func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
+func openAIEndpoint(chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		var err error
 		var model *llama.LLama
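
Note: the net effect on the wire format is that the completion endpoints now accept stop, repeat_penalty, n_keep, and seed alongside the standard OpenAI fields. A minimal client-side sketch, assuming only the JSON tags shown in the struct above; the model name, port, and path are illustrative, not prescribed by this commit:

    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
        "net/http"
    )

    // Mirrors the relevant subset of OpenAIRequest via its JSON tags.
    type completionRequest struct {
        Model         string  `json:"model"`
        Prompt        string  `json:"prompt"`
        Stop          string  `json:"stop"`           // new: stopword for generation
        RepeatPenalty float64 `json:"repeat_penalty"` // new: penalty for repeated tokens
        Keep          int     `json:"n_keep"`         // new: prompt tokens to keep
        Seed          int     `json:"seed"`           // new: sampling seed
    }

    func main() {
        body, _ := json.Marshal(completionRequest{
            Model:         "ggml-model.bin", // illustrative model name
            Prompt:        "Q: What is Go? A:",
            Stop:          "\n",
            RepeatPenalty: 1.1,
            Keep:          64,
            Seed:          42,
        })
        // Address and path are assumptions; adjust to your deployment.
        resp, err := http.Post("http://localhost:8080/v1/completions",
            "application/json", bytes.NewReader(body))
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        fmt.Println(resp.Status)
    }
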
@@ -269,6 +273,22 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 			llama.SetThreads(threads),
 		}
 
+		if debug {
+			predictOptions = append(predictOptions, llama.Debug)
+		}
+
+		if input.Stop != "" {
+			predictOptions = append(predictOptions, llama.SetStopWords(input.Stop))
+		}
+
+		if input.RepeatPenalty != 0 {
+			predictOptions = append(predictOptions, llama.SetPenalty(input.RepeatPenalty))
+		}
+
+		if input.Keep != 0 {
+			predictOptions = append(predictOptions, llama.SetNKeep(input.Keep))
+		}
+
 		if input.Batch != 0 {
 			predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
 		}
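
Note: the hunk above follows the functional-options pattern, translating each request field into an option appended to predictOptions. A self-contained sketch of that pattern; PredictOption, predictConfig, and the setters here are hypothetical stand-ins for illustration, not the go-llama.cpp API:

    package main

    import "fmt"

    type predictConfig struct {
        debug     bool
        stopWords string
        penalty   float64
    }

    // A PredictOption mutates the configuration before prediction runs.
    type PredictOption func(*predictConfig)

    // Debug matches PredictOption's signature, so it can be appended directly,
    // as llama.Debug is in the diff above.
    func Debug(c *predictConfig) { c.debug = true }

    func SetStopWords(s string) PredictOption {
        return func(c *predictConfig) { c.stopWords = s }
    }

    func newConfig(opts ...PredictOption) predictConfig {
        c := predictConfig{penalty: 1.0} // defaults apply unless overridden
        for _, o := range opts {
            o(&c)
        }
        return c
    }

    func main() {
        opts := []PredictOption{}
        // Mirror the handler: only append options for fields the caller set.
        debug, stop := true, "\n"
        if debug {
            opts = append(opts, Debug)
        }
        if stop != "" {
            opts = append(opts, SetStopWords(stop))
        }
        fmt.Printf("%+v\n", newConfig(opts...))
    }

Because each option is appended only when the request field is non-zero, omitted fields fall through to the binding's own defaults rather than being forced to zero.
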
@@ -341,7 +361,7 @@ func listModels(loader *model.ModelLoader) func(ctx *fiber.Ctx) error {
 	}
 }
 
-func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error {
+func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool, debug bool) error {
 	// Return errors as JSON responses
 	app := fiber.New(fiber.Config{
 		// Override default error handler
@@ -371,11 +391,11 @@ func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f
 	var mumutex = &sync.Mutex{}
 
 	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
-	app.Post("/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/v1/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))
 
-	app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
-	app.Post("/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/v1/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))
 
 	app.Get("/v1/models", listModels(loader))
 	app.Get("/models", listModels(loader))