diff --git a/backend/backend.proto b/backend/backend.proto
index 0a341ca2..df21cd87 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -159,6 +159,10 @@ message Reply {
   bytes message = 1;
   int32 tokens = 2;
   int32 prompt_tokens = 3;
+  int32 timing_prompt_tokens = 4;
+  int32 timing_predicted_tokens = 5;
+  double timing_prompt_processing = 6;
+  double timing_token_generation = 7;
 }
 
 message ModelOptions {
@@ -348,4 +352,4 @@ message StatusResponse {
 message Message {
   string role = 1;
   string content = 2;
-}
\ No newline at end of file
+}
diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index 7632aebc..16b4e469 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -2414,6 +2414,15 @@ public:
         int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
         reply.set_prompt_tokens(tokens_evaluated);
 
+        int32_t timing_prompt_tokens = result.result_json.value("timings", json{}).value("prompt_n", 0);
+        reply.set_timing_prompt_tokens(timing_prompt_tokens);
+        int32_t timing_predicted_tokens = result.result_json.value("timings", json{}).value("predicted_n", 0);
+        reply.set_timing_predicted_tokens(timing_predicted_tokens);
+        double timing_prompt_processing = result.result_json.value("timings", json{}).value("prompt_ms", 0.0);
+        reply.set_timing_prompt_processing(timing_prompt_processing);
+        double timing_token_generation = result.result_json.value("timings", json{}).value("predicted_ms", 0.0);
+        reply.set_timing_token_generation(timing_token_generation);
+
         // Log Request Correlation Id
         LOG_VERBOSE("correlation:", {
             { "id", data["correlation_id"] }
@@ -2454,6 +2463,15 @@ public:
         reply->set_prompt_tokens(tokens_evaluated);
         reply->set_tokens(tokens_predicted);
         reply->set_message(completion_text);
+
+        int32_t timing_prompt_tokens = result.result_json.value("timings", json{}).value("prompt_n", 0);
+        reply->set_timing_prompt_tokens(timing_prompt_tokens);
+        int32_t timing_predicted_tokens = result.result_json.value("timings", json{}).value("predicted_n", 0);
+        reply->set_timing_predicted_tokens(timing_predicted_tokens);
+        double timing_prompt_processing = result.result_json.value("timings", json{}).value("prompt_ms", 0.0);
+        reply->set_timing_prompt_processing(timing_prompt_processing);
+        double timing_token_generation = result.result_json.value("timings", json{}).value("predicted_ms", 0.0);
+        reply->set_timing_token_generation(timing_token_generation);
     }
     else
     {
diff --git a/core/backend/llm.go b/core/backend/llm.go
index 9a4d0d46..378159aa 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -27,8 +27,12 @@ type LLMResponse struct {
 }
 
 type TokenUsage struct {
-    Prompt     int
-    Completion int
+    Prompt                 int
+    Completion             int
+    TimingPromptTokens     int
+    TimingPredictedTokens  int
+    TimingPromptProcessing float64
+    TimingTokenGeneration  float64
 }
 
 func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
@@ -123,6 +127,10 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
             tokenUsage.Prompt = int(reply.PromptTokens)
             tokenUsage.Completion = int(reply.Tokens)
 
+            tokenUsage.TimingPredictedTokens = int(reply.TimingPredictedTokens)
+            tokenUsage.TimingPromptTokens = int(reply.TimingPromptTokens)
+            tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
+            tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
 
             for len(partialRune) > 0 {
                 r, size := utf8.DecodeRune(partialRune)
@@ -157,6 +165,12 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
         if tokenUsage.Completion == 0 {
             tokenUsage.Completion = int(reply.Tokens)
         }
+
+        tokenUsage.TimingPredictedTokens = int(reply.TimingPredictedTokens)
+        tokenUsage.TimingPromptTokens = int(reply.TimingPromptTokens)
+        tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
+        tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
+
         return LLMResponse{
             Response: string(reply.Message),
             Usage:    tokenUsage,
diff --git a/core/cli/run.go b/core/cli/run.go
index a0e16155..e31f3ce0 100644
--- a/core/cli/run.go
+++ b/core/cli/run.go
@@ -71,6 +71,7 @@ type RunCMD struct {
     Federated              bool     `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
     DisableGalleryEndpoint bool     `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
     LoadToMemory           []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
+    MachineTag             string   `env:"LOCALAI_MACHINE_TAG" help:"Value for the LocalAI-Machine-Tag response header; defaults to the hostname if empty"`
 }
 
 func (r *RunCMD) Run(ctx *cliContext.Context) error {
@@ -107,6 +108,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
         config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
         config.WithP2PNetworkID(r.Peer2PeerNetworkID),
         config.WithLoadToMemory(r.LoadToMemory),
+        config.WithMachineTag(r.MachineTag),
     }
 
     if r.DisableMetricsEndpoint {
diff --git a/core/config/application_config.go b/core/config/application_config.go
index 3f321e70..be3d1230 100644
--- a/core/config/application_config.go
+++ b/core/config/application_config.go
@@ -4,6 +4,7 @@ import (
     "context"
     "embed"
     "encoding/json"
+    "os"
     "regexp"
     "time"
 
@@ -65,6 +66,8 @@ type ApplicationConfig struct {
     ModelsURL []string
 
     WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
+
+    MachineTag string
 }
 
 type AppOption func(*ApplicationConfig)
@@ -94,6 +97,16 @@ func WithModelPath(path string) AppOption {
     }
 }
 
+func WithMachineTag(tag string) AppOption {
+    return func(o *ApplicationConfig) {
+        if tag == "" {
+            hostname, _ := os.Hostname()
+            tag = hostname
+        }
+        o.MachineTag = tag
+    }
+}
+
 func WithCors(b bool) AppOption {
     return func(o *ApplicationConfig) {
         o.CORS = b
diff --git a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go
index 58c3972d..6cd4496c 100644
--- a/core/http/endpoints/jina/rerank.go
+++ b/core/http/endpoints/jina/rerank.go
@@ -19,6 +19,8 @@ import (
 // @Router /v1/rerank [post]
 func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
+
         req := new(schema.JINARerankRequest)
         if err := c.BodyParser(req); err != nil {
             return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go
index 7c73c633..d9c36b79 100644
--- a/core/http/endpoints/localai/tts.go
+++ b/core/http/endpoints/localai/tts.go
@@ -24,6 +24,7 @@ import (
 // @Router /tts [post]
 func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
 
         input := new(schema.TTSRequest)
diff --git a/core/http/endpoints/localai/vad.go b/core/http/endpoints/localai/vad.go
index c5a5d929..cd92c6ce 100644
--- a/core/http/endpoints/localai/vad.go
+++ b/core/http/endpoints/localai/vad.go
@@ -19,6 +19,7 @@ import (
 // @Router /vad [post]
 func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
 
         input := new(schema.VADRequest)
diff --git a/core/http/endpoints/openai/assistant.go b/core/http/endpoints/openai/assistant.go
index 1d83066a..6c1bfe85 100644
--- a/core/http/endpoints/openai/assistant.go
+++ b/core/http/endpoints/openai/assistant.go
@@ -76,6 +76,7 @@ type AssistantRequest struct {
 // @Router /v1/assistants [post]
 func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         request := new(AssistantRequest)
         if err := c.BodyParser(request); err != nil {
             log.Warn().AnErr("Unable to parse AssistantRequest", err)
@@ -137,6 +138,7 @@ func generateRandomID() int64 {
 // @Router /v1/assistants [get]
 func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         // Because we're altering the existing assistants list we should just duplicate it for now.
         returnAssistants := Assistants
         // Parse query parameters
@@ -246,6 +248,7 @@ func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelNam
 // @Router /v1/assistants/{assistant_id} [delete]
 func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         assistantID := c.Params("assistant_id")
         if assistantID == "" {
             return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
@@ -278,6 +281,7 @@ func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
 // @Router /v1/assistants/{assistant_id} [get]
 func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         assistantID := c.Params("assistant_id")
         if assistantID == "" {
             return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
@@ -307,6 +311,7 @@ var (
 
 func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         request := new(schema.AssistantFileRequest)
         if err := c.BodyParser(request); err != nil {
             return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
@@ -353,6 +358,7 @@ func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelL
     }
 
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         assistantID := c.Params("assistant_id")
         if assistantID == "" {
             return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
@@ -410,6 +416,7 @@ func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelL
 func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         request := new(AssistantRequest)
         if err := c.BodyParser(request); err != nil {
             log.Warn().AnErr("Unable to parse AssistantRequest", err)
@@ -449,6 +456,7 @@ func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
 
 func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         assistantID := c.Params("assistant_id")
         fileId := c.Params("file_id")
         if assistantID == "" {
@@ -503,6 +511,7 @@ func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model
 
 func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         assistantID := c.Params("assistant_id")
         fileId := c.Params("file_id")
         if assistantID == "" {
diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index c2b201bd..bbae6994 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -30,7 +30,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
     var id, textContentToReturn string
     var created int
 
-    process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
+    process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
         initialMessage := schema.OpenAIResponse{
             ID:      id,
             Created: created,
@@ -40,18 +40,26 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
         }
         responses <- initialMessage
 
-        ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+        ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+            usage := schema.OpenAIUsage{
+                PromptTokens:     tokenUsage.Prompt,
+                CompletionTokens: tokenUsage.Completion,
+                TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
+            }
+            if extraUsage {
+                usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
+                usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
+                usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
+                usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
+            }
+
             resp := schema.OpenAIResponse{
                 ID:      id,
                 Created: created,
                 Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
                 Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
                 Object:  "chat.completion.chunk",
-                Usage: schema.OpenAIUsage{
-                    PromptTokens:     usage.Prompt,
-                    CompletionTokens: usage.Completion,
-                    TotalTokens:      usage.Prompt + usage.Completion,
-                },
+                Usage: usage,
             }
 
             responses <- resp
@@ -59,7 +67,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
         })
         close(responses)
     }
-    processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
+    processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
         result := ""
         _, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
             result += s
@@ -90,6 +98,17 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
                 log.Error().Err(err).Msg("error handling question")
                 return
             }
+            usage := schema.OpenAIUsage{
+                PromptTokens:     tokenUsage.Prompt,
+                CompletionTokens: tokenUsage.Completion,
+                TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
+            }
+            if extraUsage {
+                usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
+                usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
+                usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
+                usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
+            }
 
             resp := schema.OpenAIResponse{
                 ID:      id,
@@ -97,11 +116,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
                 Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
                 Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
                 Object:  "chat.completion.chunk",
-                Usage: schema.OpenAIUsage{
-                    PromptTokens:     tokenUsage.Prompt,
-                    CompletionTokens: tokenUsage.Completion,
-                    TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
-                },
+                Usage: usage,
             }
 
             responses <- resp
@@ -160,6 +175,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
     }
 
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", startupOptions.MachineTag)
+
         textContentToReturn = ""
         id = uuid.New().String()
         created = int(time.Now().Unix())
@@ -170,6 +187,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
         }
         c.Set("X-Correlation-ID", correlationID)
 
+        // Opt-in extra usage flag
+        extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""
+
         modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
         if err != nil {
             return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -311,6 +331,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
         c.Context().SetContentType("text/event-stream")
         //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
         // c.Set("Content-Type", "text/event-stream")
+        c.Set("LocalAI-Machine-Tag", startupOptions.MachineTag)
         c.Set("Cache-Control", "no-cache")
         c.Set("Connection", "keep-alive")
         c.Set("Transfer-Encoding", "chunked")
@@ -319,9 +340,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
         responses := make(chan schema.OpenAIResponse)
 
         if !shouldUseFn {
-            go process(predInput, input, config, ml, responses)
+            go process(predInput, input, config, ml, responses, extraUsage)
         } else {
-            go processTools(noActionName, predInput, input, config, ml, responses)
+            go processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
         }
 
         c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@@ -449,6 +470,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
             if err != nil {
                 return err
             }
+            usage := schema.OpenAIUsage{
+                PromptTokens:     tokenUsage.Prompt,
+                CompletionTokens: tokenUsage.Completion,
+                TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
+            }
+            if extraUsage {
+                usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
+                usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
+                usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
+                usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
+            }
+            fmt.Println(tokenUsage)
 
             resp := &schema.OpenAIResponse{
                 ID:      id,
@@ -456,11 +489,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
                 Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
                 Choices: result,
                 Object:  "chat.completion",
-                Usage: schema.OpenAIUsage{
-                    PromptTokens:     tokenUsage.Prompt,
-                    CompletionTokens: tokenUsage.Completion,
-                    TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
-                },
+                Usage: usage,
             }
             respData, _ := json.Marshal(resp)
             log.Debug().Msgf("Response: %s", respData)
diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go
index 04ebc847..0ee058ff 100644
--- a/core/http/endpoints/openai/completion.go
+++ b/core/http/endpoints/openai/completion.go
@@ -30,8 +30,19 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
     id := uuid.New().String()
     created := int(time.Now().Unix())
 
-    process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
-        ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+    process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+        ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+            usage := schema.OpenAIUsage{
+                PromptTokens:     tokenUsage.Prompt,
+                CompletionTokens: tokenUsage.Completion,
+                TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
+            }
+            if extraUsage {
+                usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
+                usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
+                usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
+                usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
+            }
             resp := schema.OpenAIResponse{
                 ID:      id,
                 Created: created,
@@ -43,11 +54,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
                     },
                 },
                 Object: "text_completion",
-                Usage: schema.OpenAIUsage{
-                    PromptTokens:     usage.Prompt,
-                    CompletionTokens: usage.Completion,
-                    TotalTokens:      usage.Prompt + usage.Completion,
-                },
+                Usage: usage,
             }
             log.Debug().Msgf("Sending goroutine: %s", s)
 
@@ -58,8 +65,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
     }
 
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         // Add Correlation
         c.Set("X-Correlation-ID", id)
+
+        // Opt-in extra usage flag
+        extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""
+
         modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
         if err != nil {
             return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -113,7 +125,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 
         responses := make(chan schema.OpenAIResponse)
 
-        go process(predInput, input, config, ml, responses)
+        go process(predInput, input, config, ml, responses, extraUsage)
 
         c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@@ -170,11 +182,24 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
                 return err
             }
 
-            totalTokenUsage.Prompt += tokenUsage.Prompt
-            totalTokenUsage.Completion += tokenUsage.Completion
+            totalTokenUsage.TimingPredictedTokens += tokenUsage.TimingPredictedTokens
+            totalTokenUsage.TimingPromptTokens += tokenUsage.TimingPromptTokens
+            totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration
+            totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing
 
             result = append(result, r...)
         }
+        usage := schema.OpenAIUsage{
+            PromptTokens:     totalTokenUsage.Prompt,
+            CompletionTokens: totalTokenUsage.Completion,
+            TotalTokens:      totalTokenUsage.Prompt + totalTokenUsage.Completion,
+        }
+        if extraUsage {
+            usage.TimingPredictedTokens = totalTokenUsage.TimingPredictedTokens
+            usage.TimingPromptTokens = totalTokenUsage.TimingPromptTokens
+            usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration
+            usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing
+        }
 
         resp := &schema.OpenAIResponse{
             ID:      id,
@@ -182,11 +207,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
             Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
             Choices: result,
             Object:  "text_completion",
-            Usage: schema.OpenAIUsage{
-                PromptTokens:     totalTokenUsage.Prompt,
-                CompletionTokens: totalTokenUsage.Completion,
-                TotalTokens:      totalTokenUsage.Prompt + totalTokenUsage.Completion,
-            },
+            Usage: usage,
         }
 
         jsonResult, _ := json.Marshal(resp)
diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go
index a6d609fb..e484863c 100644
--- a/core/http/endpoints/openai/edit.go
+++ b/core/http/endpoints/openai/edit.go
@@ -25,6 +25,11 @@ import (
 func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
+
+        // Opt-in extra usage flag
+        extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""
+
         modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
         if err != nil {
             return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -61,8 +66,24 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
             totalTokenUsage.Prompt += tokenUsage.Prompt
             totalTokenUsage.Completion += tokenUsage.Completion
 
+            totalTokenUsage.TimingPredictedTokens += tokenUsage.TimingPredictedTokens
+            totalTokenUsage.TimingPromptTokens += tokenUsage.TimingPromptTokens
+            totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration
+            totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing
+
             result = append(result, r...)
         }
+        usage := schema.OpenAIUsage{
+            PromptTokens:     totalTokenUsage.Prompt,
+            CompletionTokens: totalTokenUsage.Completion,
+            TotalTokens:      totalTokenUsage.Prompt + totalTokenUsage.Completion,
+        }
+        if extraUsage {
+            usage.TimingPredictedTokens = totalTokenUsage.TimingPredictedTokens
+            usage.TimingPromptTokens = totalTokenUsage.TimingPromptTokens
+            usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration
+            usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing
+        }
 
         id := uuid.New().String()
         created := int(time.Now().Unix())
@@ -72,11 +93,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
             Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
             Choices: result,
             Object:  "edit",
-            Usage: schema.OpenAIUsage{
-                PromptTokens:     totalTokenUsage.Prompt,
-                CompletionTokens: totalTokenUsage.Completion,
-                TotalTokens:      totalTokenUsage.Prompt + totalTokenUsage.Completion,
-            },
+            Usage: usage,
         }
 
         jsonResult, _ := json.Marshal(resp)
diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go
index e247d84e..636f069e 100644
--- a/core/http/endpoints/openai/embeddings.go
+++ b/core/http/endpoints/openai/embeddings.go
@@ -23,6 +23,7 @@ import (
 // @Router /v1/embeddings [post]
 func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         model, input, err := readRequest(c, cl, ml, appConfig, true)
         if err != nil {
             return fmt.Errorf("failed reading parameters from request:%w", err)
diff --git a/core/http/endpoints/openai/files.go b/core/http/endpoints/openai/files.go
index bc392e73..fcf65af1 100644
--- a/core/http/endpoints/openai/files.go
+++ b/core/http/endpoints/openai/files.go
@@ -23,6 +23,7 @@ const UploadedFilesFile = "uploadedFiles.json"
 // UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
 func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         file, err := c.FormFile("file")
         if err != nil {
             return err
@@ -82,6 +83,7 @@ func getNextFileId() int64 {
 func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         var listFiles schema.ListFiles
 
         purpose := c.Query("purpose")
@@ -120,6 +122,7 @@ func getFileFromRequest(c *fiber.Ctx) (*schema.File, error) {
 // @Router /v1/files/{file_id} [get]
 func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         file, err := getFileFromRequest(c)
         if err != nil {
             return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
@@ -142,6 +145,7 @@ type DeleteStatus struct {
 
 func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         file, err := getFileFromRequest(c)
         if err != nil {
             return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
@@ -179,6 +183,7 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
 // GetFilesContentsEndpoint
 func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         file, err := getFileFromRequest(c)
         if err != nil {
             return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go
index 3fdb64d4..ad676c33 100644
--- a/core/http/endpoints/openai/image.go
+++ b/core/http/endpoints/openai/image.go
@@ -66,6 +66,7 @@ func downloadFile(url string) (string, error) {
 // @Router /v1/images/generations [post]
 func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         m, input, err := readRequest(c, cl, ml, appConfig, false)
         if err != nil {
             return fmt.Errorf("failed reading parameters from request:%w", err)
diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go
index 80dcb3e4..7f718533 100644
--- a/core/http/endpoints/openai/list.go
+++ b/core/http/endpoints/openai/list.go
@@ -12,8 +12,10 @@ import (
 // @Summary List and describe the various models available in the API.
 // @Success 200 {object} schema.ModelsDataResponse "Response"
 // @Router /v1/models [get]
-func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
+func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
+
         // If blank, no filter is applied.
         filter := c.Query("filter")
diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go
index 4e23f804..ba43a302 100644
--- a/core/http/endpoints/openai/transcription.go
+++ b/core/http/endpoints/openai/transcription.go
@@ -25,6 +25,7 @@ import (
 // @Router /v1/audio/transcriptions [post]
 func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
     return func(c *fiber.Ctx) error {
+        c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
         m, input, err := readRequest(c, cl, ml, appConfig, false)
         if err != nil {
             return fmt.Errorf("failed reading parameters from request:%w", err)
diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go
index 5ff301b6..a48ced65 100644
--- a/core/http/routes/openai.go
+++ b/core/http/routes/openai.go
@@ -130,6 +130,6 @@ func RegisterOpenAIRoutes(app *fiber.App,
     }
 
     // List models
-    app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
-    app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
+    app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+    app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
 }
diff --git a/core/schema/openai.go b/core/schema/openai.go
index 15bcd13d..c339f6ac 100644
--- a/core/schema/openai.go
+++ b/core/schema/openai.go
@@ -23,6 +23,11 @@ type OpenAIUsage struct {
     PromptTokens     int `json:"prompt_tokens"`
     CompletionTokens int `json:"completion_tokens"`
     TotalTokens      int `json:"total_tokens"`
+    // Extra timing data, disabled by default as it's not part of the OpenAI specification
+    TimingPromptTokens     int     `json:"timing_prompt_tokens,omitempty"`
+    TimingPredictedTokens  int     `json:"timing_predicted_tokens,omitempty"`
+    TimingPromptProcessing float64 `json:"timing_prompt_processing,omitempty"`
+    TimingTokenGeneration  float64 `json:"timing_token_generation,omitempty"`
 }
 
 type Item struct {
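
The four new `timing_*` fields on `Reply` are populated in grpc-server.cpp from the `timings` object of the llama.cpp server result JSON. Inferring only from the keys the bridge reads above (llama.cpp may report additional fields that this patch ignores), the relevant shape is roughly:

```go
// Sketch of the "timings" object consumed by the C++ bridge above; the field
// list is inferred from the .value() calls in grpc-server.cpp, not a complete
// description of what llama.cpp emits.
type llamaTimings struct {
	PromptN     int     `json:"prompt_n"`     // -> Reply.timing_prompt_tokens
	PredictedN  int     `json:"predicted_n"`  // -> Reply.timing_predicted_tokens
	PromptMS    float64 `json:"prompt_ms"`    // -> Reply.timing_prompt_processing
	PredictedMS float64 `json:"predicted_ms"` // -> Reply.timing_token_generation
}
```

These values are then copied into `backend.TokenUsage` in core/backend/llm.go and, from there, into the OpenAI-compatible `usage` object when the client opts in.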
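On the HTTP side the change is opt-in twice over: every endpoint now stamps the `LocalAI-Machine-Tag` response header (configured with `LOCALAI_MACHINE_TAG`, falling back to the hostname), while the OpenAI-compatible endpoints only copy the timing data into `usage` when the request carries a non-empty `LocalAI-Extra-Usage` header. A minimal client sketch follows; the base URL, port and model name are illustrative assumptions rather than part of the patch, and the tokens-per-second arithmetic assumes the llama.cpp timings are reported in milliseconds, as the `prompt_ms`/`predicted_ms` key names suggest.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Mirrors schema.OpenAIUsage from the diff, including the opt-in timing fields.
type usage struct {
	PromptTokens           int     `json:"prompt_tokens"`
	CompletionTokens       int     `json:"completion_tokens"`
	TotalTokens            int     `json:"total_tokens"`
	TimingPromptTokens     int     `json:"timing_prompt_tokens"`
	TimingPredictedTokens  int     `json:"timing_predicted_tokens"`
	TimingPromptProcessing float64 `json:"timing_prompt_processing"`
	TimingTokenGeneration  float64 `json:"timing_token_generation"`
}

func main() {
	// "my-model" is a hypothetical model name; use one installed on your instance.
	body, _ := json.Marshal(map[string]any{
		"model":    "my-model",
		"messages": []map[string]string{{"role": "user", "content": "Hello"}},
	})

	// Assumes a LocalAI instance on the default local address; adjust as needed.
	req, err := http.NewRequest(http.MethodPost, "http://localhost:8080/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// Opt in to the extra timing fields; any non-empty value works, since the
	// handler only checks that the header is set.
	req.Header.Set("LocalAI-Extra-Usage", "1")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Every endpoint now reports which machine served the request.
	fmt.Println("served by:", resp.Header.Get("LocalAI-Machine-Tag"))

	var out struct {
		Usage usage `json:"usage"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("prompt/completion tokens: %d/%d\n", out.Usage.PromptTokens, out.Usage.CompletionTokens)

	// Assuming predicted_ms is in milliseconds, this yields tokens per second.
	if out.Usage.TimingTokenGeneration > 0 {
		tps := float64(out.Usage.TimingPredictedTokens) / (out.Usage.TimingTokenGeneration / 1000.0)
		fmt.Printf("generation speed: %.1f tok/s\n", tps)
	}
}
```

Because the new fields are tagged `omitempty`, clients that never send `LocalAI-Extra-Usage` keep receiving an unchanged OpenAI `usage` object.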