mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
Add stopwords, debug mode, and other API enhancements (#54)
Signed-off-by: mudler <mudler@mocaccino.org>
This commit is contained in:
parent
4b7e83056d
commit
5cba71de70
5 changed files with 36 additions and 13 deletions
38
api/api.go
38
api/api.go
|
@ -48,6 +48,8 @@ type OpenAIRequest struct {
|
|||
// Prompt is read only by completion API calls
|
||||
Prompt string `json:"prompt"`
|
||||
|
||||
Stop string `json:"stop"`
|
||||
|
||||
// Messages is read only by chat/completion API calls
|
||||
Messages []Message `json:"messages"`
|
||||
|
||||
|
@ -61,15 +63,17 @@ type OpenAIRequest struct {
|
|||
N int `json:"n"`
|
||||
|
||||
// Custom parameters - not present in the OpenAI API
|
||||
Batch int `json:"batch"`
|
||||
F16 bool `json:"f16kv"`
|
||||
IgnoreEOS bool `json:"ignore_eos"`
|
||||
Batch int `json:"batch"`
|
||||
F16 bool `json:"f16kv"`
|
||||
IgnoreEOS bool `json:"ignore_eos"`
|
||||
RepeatPenalty float64 `json:"repeat_penalty"`
|
||||
Keep int `json:"n_keep"`
|
||||
|
||||
Seed int `json:"seed"`
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
|
||||
func openAIEndpoint(chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
var err error
|
||||
var model *llama.LLama
|
||||
|
@ -269,6 +273,22 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
|
|||
llama.SetThreads(threads),
|
||||
}
|
||||
|
||||
if debug {
|
||||
predictOptions = append(predictOptions, llama.Debug)
|
||||
}
|
||||
|
||||
if input.Stop != "" {
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(input.Stop))
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetPenalty(input.RepeatPenalty))
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetNKeep(input.Keep))
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
|
||||
}
|
||||
|
@ -341,7 +361,7 @@ func listModels(loader *model.ModelLoader) func(ctx *fiber.Ctx) error {
|
|||
}
|
||||
}
|
||||
|
||||
func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error {
|
||||
func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool, debug bool) error {
|
||||
// Return errors as JSON responses
|
||||
app := fiber.New(fiber.Config{
|
||||
// Override default error handler
|
||||
|
@ -371,11 +391,11 @@ func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f
|
|||
var mumutex = &sync.Mutex{}
|
||||
|
||||
// openAI compatible API endpoint
|
||||
app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
|
||||
app.Post("/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
|
||||
app.Post("/v1/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))
|
||||
app.Post("/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))
|
||||
|
||||
app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
|
||||
app.Post("/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
|
||||
app.Post("/v1/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))
|
||||
app.Post("/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))
|
||||
|
||||
app.Get("/v1/models", listModels(loader))
|
||||
app.Get("/models", listModels(loader))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue