fix: drop racy code, refactor and group API schema (#931)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-08-20 14:04:45 +02:00 committed by GitHub
parent 28db83e17b
commit cc060a283d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 239 additions and 317 deletions

View file

@ -1,115 +0,0 @@
package openai
import (
"context"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/grammar"
)
// APIError provides error information returned by the OpenAI API.
type APIError struct {
Code any `json:"code,omitempty"`
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
}
type ErrorResponse struct {
Error *APIError `json:"error,omitempty"`
}
type OpenAIUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Item struct {
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
Object string `json:"object,omitempty"`
// Images
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
}
type OpenAIResponse struct {
Created int `json:"created,omitempty"`
Object string `json:"object,omitempty"`
ID string `json:"id,omitempty"`
Model string `json:"model,omitempty"`
Choices []Choice `json:"choices,omitempty"`
Data []Item `json:"data,omitempty"`
Usage OpenAIUsage `json:"usage"`
}
type Choice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason,omitempty"`
Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"`
Text string `json:"text,omitempty"`
}
type Message struct {
// The message role
Role string `json:"role,omitempty" yaml:"role"`
// The message content
Content *string `json:"content" yaml:"content"`
// A result of a function call
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
}
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
}
type OpenAIRequest struct {
config.PredictionOptions
Context context.Context
Cancel context.CancelFunc
// whisper
File string `json:"file" validate:"required"`
//whisper/image
ResponseFormat string `json:"response_format"`
// image
Size string `json:"size"`
// Prompt is read only by completion/image API calls
Prompt interface{} `json:"prompt" yaml:"prompt"`
// Edit endpoint
Instruction string `json:"instruction" yaml:"instruction"`
Input interface{} `json:"input" yaml:"input"`
Stop interface{} `json:"stop" yaml:"stop"`
// Messages is read only by chat/completion API calls
Messages []Message `json:"messages" yaml:"messages"`
// A list of available functions to call
Functions []grammar.Function `json:"functions" yaml:"functions"`
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
Stream bool `json:"stream"`
// Image (not supported by OpenAI)
Mode int `json:"mode"`
Step int `json:"step"`
// A grammar to constrain the LLM output
Grammar string `json:"grammar" yaml:"grammar"`
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
Backend string `json:"backend" yaml:"backend"`
// AutoGPTQ
ModelBaseName string `json:"model_base_name" yaml:"model_base_name"`
}

View file

@ -10,6 +10,7 @@ import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
@ -21,20 +22,20 @@ import (
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
emptyMessage := ""
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
initialMessage := OpenAIResponse{
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{Delta: &Message{Role: "assistant", Content: &emptyMessage}}},
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := OpenAIResponse{
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: OpenAIUsage{
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
@ -236,13 +237,13 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
if toStream {
responses := make(chan OpenAIResponse)
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, o.Loader, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &OpenAIUsage{}
usage := &schema.OpenAIUsage{}
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
@ -259,13 +260,13 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
w.Flush()
}
resp := &OpenAIResponse{
resp := &schema.OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{
Choices: []schema.Choice{
{
FinishReason: "stop",
Index: 0,
Delta: &Message{Content: &emptyMessage},
Delta: &schema.Message{Content: &emptyMessage},
}},
Object: "chat.completion.chunk",
Usage: *usage,
@ -279,7 +280,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return nil
}
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) {
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
if processFunctions {
// As we have to change the result before processing, we can't stream the answer (yet?)
ss := map[string]interface{}{}
@ -313,7 +314,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
message = backend.Finetune(*config, predInput, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}})
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
return
}
}
@ -336,28 +337,28 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &fineTunedResponse}})
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
} else {
// otherwise reply with the function call
*c = append(*c, Choice{
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &Message{Role: "assistant", FunctionCall: ss},
Message: &schema.Message{Role: "assistant", FunctionCall: ss},
})
}
return
}
*c = append(*c, Choice{FinishReason: "stop", Index: 0, Message: &Message{Role: "assistant", Content: &s}})
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
}, nil)
if err != nil {
return err
}
resp := &OpenAIResponse{
resp := &schema.OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: OpenAIUsage{
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,

View file

@ -10,6 +10,7 @@ import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
@ -18,18 +19,18 @@ import (
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := OpenAIResponse{
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{
Choices: []schema.Choice{
{
Index: 0,
Text: s,
},
},
Object: "text_completion",
Usage: OpenAIUsage{
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
@ -90,7 +91,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
responses := make(chan OpenAIResponse)
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, o.Loader, responses)
@ -106,9 +107,9 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
w.Flush()
}
resp := &OpenAIResponse{
resp := &schema.OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{
Choices: []schema.Choice{
{
Index: 0,
FinishReason: "stop",
@ -125,7 +126,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
return nil
}
var result []Choice
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
@ -140,9 +141,10 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
r, tokenUsage, err := ComputeChoices(
input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {
return err
}
@ -153,11 +155,11 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
result = append(result, r...)
}
resp := &OpenAIResponse{
resp := &schema.OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: OpenAIUsage{
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,

View file

@ -7,8 +7,10 @@ import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
@ -32,7 +34,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
templateFile = config.TemplateConfig.Edit
}
var result []Choice
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for _, i := range config.InputStrings {
@ -47,8 +49,8 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
}, nil)
if err != nil {
return err
@ -60,11 +62,11 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
result = append(result, r...)
}
resp := &OpenAIResponse{
resp := &schema.OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "edit",
Usage: OpenAIUsage{
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,

View file

@ -6,6 +6,8 @@ import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/api/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
@ -25,7 +27,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
}
log.Debug().Msgf("Parameter Config: %+v", config)
items := []Item{}
items := []schema.Item{}
for i, s := range config.InputToken {
// get the model function to call for the result
@ -38,7 +40,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
if err != nil {
return err
}
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range config.InputStrings {
@ -52,10 +54,10 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
if err != nil {
return err
}
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
resp := &OpenAIResponse{
resp := &schema.OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",

View file

@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/api/schema"
"os"
"path/filepath"
"strconv"
@ -100,7 +101,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
b64JSON = true
}
// src and clip_skip
var result []Item
var result []schema.Item
for _, i := range config.PromptStrings {
n := input.N
if input.N == 0 {
@ -155,7 +156,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
return err
}
item := &Item{}
item := &schema.Item{}
if b64JSON {
defer os.RemoveAll(output)
@ -173,7 +174,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
}
}
resp := &OpenAIResponse{
resp := &schema.OpenAIResponse{
Data: result,
}

View file

@ -4,12 +4,20 @@ import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string, backend.TokenUsage) bool) ([]Choice, backend.TokenUsage, error) {
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *config.Config,
o *options.Option,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
n := req.N // number of completions to return
result := []Choice{}
result := []schema.Choice{}
if n == 0 {
n = 1

View file

@ -4,6 +4,7 @@ import (
"regexp"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
@ -16,7 +17,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
}
var mm map[string]interface{} = map[string]interface{}{}
dataModels := []OpenAIModel{}
dataModels := []schema.OpenAIModel{}
var filterFn func(name string) bool
filter := c.Query("filter")
@ -45,7 +46,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
}
if filterFn(c.Name) {
dataModels = append(dataModels, OpenAIModel{ID: c.Name, Object: "model"})
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
}
}
@ -53,13 +54,13 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
for _, m := range models {
// And only adds them if they shouldn't be skipped.
if _, exists := mm[m]; !exists && filterFn(m) {
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
}
return c.JSON(struct {
Object string `json:"object"`
Data []OpenAIModel `json:"data"`
Object string `json:"object"`
Data []schema.OpenAIModel `json:"data"`
}{
Object: "list",
Data: dataModels,

View file

@ -10,14 +10,15 @@ import (
config "github.com/go-skynet/LocalAI/api/config"
options "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *OpenAIRequest, error) {
func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *schema.OpenAIRequest, error) {
loader := o.Loader
input := new(OpenAIRequest)
input := new(schema.OpenAIRequest)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
@ -60,7 +61,7 @@ func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *Open
return modelFile, input, nil
}
func updateConfig(config *config.Config, input *OpenAIRequest) {
func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
@ -218,7 +219,7 @@ func updateConfig(config *config.Config, input *OpenAIRequest) {
}
}
func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) {
func readConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")