mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-29 14:14:59 +00:00
Add machine tag option, add extraUsage option, grpc-server -> proto -> endpoint extraUsage data is broken for now
Signed-off-by: mintyleaf <mintyleafdev@gmail.com>
This commit is contained in:
parent
20edd44463
commit
f040aa46a3
19 changed files with 192 additions and 46 deletions
|
@ -159,6 +159,10 @@ message Reply {
|
|||
bytes message = 1;
|
||||
int32 tokens = 2;
|
||||
int32 prompt_tokens = 3;
|
||||
int32 timing_prompt_tokens = 4;
|
||||
int32 timing_predicted_tokens = 5;
|
||||
double timing_prompt_processing = 6;
|
||||
double timing_token_generation = 7;
|
||||
}
|
||||
|
||||
message ModelOptions {
|
||||
|
@ -348,4 +352,4 @@ message StatusResponse {
|
|||
message Message {
|
||||
string role = 1;
|
||||
string content = 2;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2414,6 +2414,15 @@ public:
|
|||
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
|
||||
reply.set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
int32_t timing_prompt_tokens = result.result_json.value("timings", json{}).value("prompt_n", 0);
|
||||
reply.set_timing_prompt_tokens(timing_prompt_tokens);
|
||||
int32_t timing_predicted_tokens = result.result_json.value("timings", json{}).value("predicted_n", 0);
|
||||
reply.set_timing_predicted_tokens(timing_predicted_tokens);
|
||||
double timing_prompt_processing = result.result_json.value("timings", json{}).value("prompt_ms", 0.0);
|
||||
reply.set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = result.result_json.value("timings", json{}).value("predicted_ms", 0.0);
|
||||
reply.set_timing_token_generation(timing_token_generation);
|
||||
|
||||
// Log Request Correlation Id
|
||||
LOG_VERBOSE("correlation:", {
|
||||
{ "id", data["correlation_id"] }
|
||||
|
@ -2454,6 +2463,15 @@ public:
|
|||
reply->set_prompt_tokens(tokens_evaluated);
|
||||
reply->set_tokens(tokens_predicted);
|
||||
reply->set_message(completion_text);
|
||||
|
||||
int32_t timing_prompt_tokens = result.result_json.value("timings", json{}).value("prompt_n", 0);
|
||||
reply->set_timing_prompt_tokens(timing_prompt_tokens);
|
||||
int32_t timing_predicted_tokens = result.result_json.value("timings", json{}).value("predicted_n", 0);
|
||||
reply->set_timing_predicted_tokens(timing_predicted_tokens);
|
||||
double timing_prompt_processing = result.result_json.value("timings", json{}).value("prompt_ms", 0.0);
|
||||
reply->set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = result.result_json.value("timings", json{}).value("predicted_ms", 0.0);
|
||||
reply->set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
@ -27,8 +27,12 @@ type LLMResponse struct {
|
|||
}
|
||||
|
||||
type TokenUsage struct {
|
||||
Prompt int
|
||||
Completion int
|
||||
Prompt int
|
||||
Completion int
|
||||
TimingPromptTokens int
|
||||
TimingPredictedTokens int
|
||||
TimingPromptProcessing float64
|
||||
TimingTokenGeneration float64
|
||||
}
|
||||
|
||||
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
||||
|
@ -123,6 +127,10 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
|||
|
||||
tokenUsage.Prompt = int(reply.PromptTokens)
|
||||
tokenUsage.Completion = int(reply.Tokens)
|
||||
tokenUsage.TimingPredictedTokens = int(reply.TimingPredictedTokens)
|
||||
tokenUsage.TimingPromptTokens = int(reply.TimingPromptTokens)
|
||||
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
|
||||
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
|
||||
|
||||
for len(partialRune) > 0 {
|
||||
r, size := utf8.DecodeRune(partialRune)
|
||||
|
@ -157,6 +165,12 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
|||
if tokenUsage.Completion == 0 {
|
||||
tokenUsage.Completion = int(reply.Tokens)
|
||||
}
|
||||
|
||||
tokenUsage.TimingPredictedTokens = int(reply.TimingPredictedTokens)
|
||||
tokenUsage.TimingPromptTokens = int(reply.TimingPromptTokens)
|
||||
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
|
||||
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
|
||||
|
||||
return LLMResponse{
|
||||
Response: string(reply.Message),
|
||||
Usage: tokenUsage,
|
||||
|
|
|
@ -71,6 +71,7 @@ type RunCMD struct {
|
|||
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
|
||||
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
|
||||
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
|
||||
MachineTag string `env:"LOCALAI_MACHINE_TAG" help:"TODO: write a help string"`
|
||||
}
|
||||
|
||||
func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
|
@ -107,6 +108,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||
config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
|
||||
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
|
||||
config.WithLoadToMemory(r.LoadToMemory),
|
||||
config.WithMachineTag(r.MachineTag),
|
||||
}
|
||||
|
||||
if r.DisableMetricsEndpoint {
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
|
@ -65,6 +66,8 @@ type ApplicationConfig struct {
|
|||
ModelsURL []string
|
||||
|
||||
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
||||
|
||||
MachineTag string
|
||||
}
|
||||
|
||||
type AppOption func(*ApplicationConfig)
|
||||
|
@ -94,6 +97,16 @@ func WithModelPath(path string) AppOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithMachineTag(tag string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
if tag == "" {
|
||||
hostname, _ := os.Hostname()
|
||||
tag = hostname
|
||||
}
|
||||
o.MachineTag = tag
|
||||
}
|
||||
}
|
||||
|
||||
func WithCors(b bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.CORS = b
|
||||
|
|
|
@ -19,6 +19,8 @@ import (
|
|||
// @Router /v1/rerank [post]
|
||||
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
|
||||
req := new(schema.JINARerankRequest)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
// @Router /tts [post]
|
||||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
|
||||
input := new(schema.TTSRequest)
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
// @Router /vad [post]
|
||||
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
|
||||
input := new(schema.VADRequest)
|
||||
|
||||
|
|
|
@ -76,6 +76,7 @@ type AssistantRequest struct {
|
|||
// @Router /v1/assistants [post]
|
||||
func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
request := new(AssistantRequest)
|
||||
if err := c.BodyParser(request); err != nil {
|
||||
log.Warn().AnErr("Unable to parse AssistantRequest", err)
|
||||
|
@ -137,6 +138,7 @@ func generateRandomID() int64 {
|
|||
// @Router /v1/assistants [get]
|
||||
func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
// Because we're altering the existing assistants list we should just duplicate it for now.
|
||||
returnAssistants := Assistants
|
||||
// Parse query parameters
|
||||
|
@ -246,6 +248,7 @@ func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelNam
|
|||
// @Router /v1/assistants/{assistant_id} [delete]
|
||||
func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
assistantID := c.Params("assistant_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||
|
@ -278,6 +281,7 @@ func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
|
|||
// @Router /v1/assistants/{assistant_id} [get]
|
||||
func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
assistantID := c.Params("assistant_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||
|
@ -307,6 +311,7 @@ var (
|
|||
|
||||
func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
request := new(schema.AssistantFileRequest)
|
||||
if err := c.BodyParser(request); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||
|
@ -353,6 +358,7 @@ func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelL
|
|||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
assistantID := c.Params("assistant_id")
|
||||
if assistantID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||
|
@ -410,6 +416,7 @@ func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelL
|
|||
|
||||
func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
request := new(AssistantRequest)
|
||||
if err := c.BodyParser(request); err != nil {
|
||||
log.Warn().AnErr("Unable to parse AssistantRequest", err)
|
||||
|
@ -449,6 +456,7 @@ func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
|
|||
|
||||
func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
assistantID := c.Params("assistant_id")
|
||||
fileId := c.Params("file_id")
|
||||
if assistantID == "" {
|
||||
|
@ -503,6 +511,7 @@ func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model
|
|||
|
||||
func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
assistantID := c.Params("assistant_id")
|
||||
fileId := c.Params("file_id")
|
||||
if assistantID == "" {
|
||||
|
|
|
@ -30,7 +30,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
var id, textContentToReturn string
|
||||
var created int
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
|
@ -40,18 +40,26 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
}
|
||||
responses <- initialMessage
|
||||
|
||||
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
|
||||
usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
|
||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||
}
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
},
|
||||
Usage: usage,
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
|
@ -59,7 +67,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
})
|
||||
close(responses)
|
||||
}
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
result := ""
|
||||
_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
|
@ -90,6 +98,17 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
log.Error().Err(err).Msg("error handling question")
|
||||
return
|
||||
}
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
|
||||
usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
|
||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||
}
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
|
@ -97,11 +116,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
},
|
||||
Usage: usage,
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
|
@ -160,6 +175,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", startupOptions.MachineTag)
|
||||
|
||||
textContentToReturn = ""
|
||||
id = uuid.New().String()
|
||||
created = int(time.Now().Unix())
|
||||
|
@ -170,6 +187,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
}
|
||||
c.Set("X-Correlation-ID", correlationID)
|
||||
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""
|
||||
|
||||
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
|
@ -311,6 +331,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("LocalAI-Machine-Tag", startupOptions.MachineTag)
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
|
@ -319,9 +340,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
if !shouldUseFn {
|
||||
go process(predInput, input, config, ml, responses)
|
||||
go process(predInput, input, config, ml, responses, extraUsage)
|
||||
} else {
|
||||
go processTools(noActionName, predInput, input, config, ml, responses)
|
||||
go processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
|
||||
}
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
@ -449,6 +470,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
|
||||
usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
|
||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||
}
|
||||
fmt.Println(tokenUsage)
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
|
@ -456,11 +489,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "chat.completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
},
|
||||
Usage: usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", respData)
|
||||
|
|
|
@ -30,8 +30,19 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
|||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
|
||||
usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
|
||||
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
|
||||
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
|
||||
}
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
|
@ -43,11 +54,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
|||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
},
|
||||
Usage: usage,
|
||||
}
|
||||
log.Debug().Msgf("Sending goroutine: %s", s)
|
||||
|
||||
|
@ -58,8 +65,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
|||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
// Add Correlation
|
||||
c.Set("X-Correlation-ID", id)
|
||||
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""
|
||||
|
||||
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
|
@ -113,7 +125,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
|||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
go process(predInput, input, config, ml, responses)
|
||||
go process(predInput, input, config, ml, responses, extraUsage)
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
|
@ -170,11 +182,24 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
|||
return err
|
||||
}
|
||||
|
||||
totalTokenUsage.Prompt += tokenUsage.Prompt
|
||||
totalTokenUsage.Completion += tokenUsage.Completion
|
||||
totalTokenUsage.TimingPredictedTokens += tokenUsage.TimingPredictedTokens
|
||||
totalTokenUsage.TimingPromptTokens += tokenUsage.TimingPromptTokens
|
||||
totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration
|
||||
totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
usage.TimingPredictedTokens = totalTokenUsage.TimingPredictedTokens
|
||||
usage.TimingPromptTokens = totalTokenUsage.TimingPromptTokens
|
||||
usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration
|
||||
usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
|
@ -182,11 +207,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
|||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "text_completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
},
|
||||
Usage: usage,
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
|
|
|
@ -25,6 +25,11 @@ import (
|
|||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""
|
||||
|
||||
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
|
@ -61,8 +66,24 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
totalTokenUsage.Prompt += tokenUsage.Prompt
|
||||
totalTokenUsage.Completion += tokenUsage.Completion
|
||||
|
||||
totalTokenUsage.TimingPredictedTokens += tokenUsage.TimingPredictedTokens
|
||||
totalTokenUsage.TimingPromptTokens += tokenUsage.TimingPromptTokens
|
||||
totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration
|
||||
totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
}
|
||||
if extraUsage {
|
||||
usage.TimingPredictedTokens = totalTokenUsage.TimingPredictedTokens
|
||||
usage.TimingPromptTokens = totalTokenUsage.TimingPromptTokens
|
||||
usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration
|
||||
usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
@ -72,11 +93,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "edit",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
},
|
||||
Usage: usage,
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
// @Router /v1/embeddings [post]
|
||||
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
model, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
|
|
|
@ -23,6 +23,7 @@ const UploadedFilesFile = "uploadedFiles.json"
|
|||
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
|
||||
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -82,6 +83,7 @@ func getNextFileId() int64 {
|
|||
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
var listFiles schema.ListFiles
|
||||
|
||||
purpose := c.Query("purpose")
|
||||
|
@ -120,6 +122,7 @@ func getFileFromRequest(c *fiber.Ctx) (*schema.File, error) {
|
|||
// @Router /v1/files/{file_id} [get]
|
||||
func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
file, err := getFileFromRequest(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
|
@ -142,6 +145,7 @@ type DeleteStatus struct {
|
|||
func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
file, err := getFileFromRequest(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
|
@ -179,6 +183,7 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
|
|||
// GetFilesContentsEndpoint
|
||||
func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
file, err := getFileFromRequest(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
|
|
|
@ -66,6 +66,7 @@ func downloadFile(url string) (string, error) {
|
|||
// @Router /v1/images/generations [post]
|
||||
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
m, input, err := readRequest(c, cl, ml, appConfig, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
|
|
|
@ -12,8 +12,10 @@ import (
|
|||
// @Summary List and describe the various models available in the API.
|
||||
// @Success 200 {object} schema.ModelsDataResponse "Response"
|
||||
// @Router /v1/models [get]
|
||||
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
|
||||
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
|
||||
// If blank, no filter is applied.
|
||||
filter := c.Query("filter")
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
// @Router /v1/audio/transcriptions [post]
|
||||
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
|
||||
m, input, err := readRequest(c, cl, ml, appConfig, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
|
|
|
@ -130,6 +130,6 @@ func RegisterOpenAIRoutes(app *fiber.App,
|
|||
}
|
||||
|
||||
// List models
|
||||
app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
|
||||
app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
|
||||
app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
}
|
||||
|
|
|
@ -23,6 +23,11 @@ type OpenAIUsage struct {
|
|||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
// Extra timing data, disabled by default as is't not a part of OpenAI specification
|
||||
TimingPromptTokens int `json:"timing_prompt_tokens,omitempty"`
|
||||
TimingPredictedTokens int `json:"timing_predicted_tokens,omitempty"`
|
||||
TimingPromptProcessing float64 `json:"timing_prompt_processing,omitempty"`
|
||||
TimingTokenGeneration float64 `json:"timing_token_generation,omitempty"`
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue