Add machine tag option and extraUsage option; the grpc-server -> proto -> endpoint extraUsage data path is broken for now

Signed-off-by: mintyleaf <mintyleafdev@gmail.com>
mintyleaf 2025-01-09 04:49:57 +04:00
parent 20edd44463
commit f040aa46a3
19 changed files with 192 additions and 46 deletions
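
Taken together, the two options are exercised from the client side: any non-empty LocalAI-Extra-Usage request header opts in to the extra timing fields in the usage object, and every response carries a LocalAI-Machine-Tag header identifying the serving instance. A minimal sketch of the flow, assuming a LocalAI server on localhost:8080 (the model name is illustrative):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)
	req, err := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// Any non-empty value enables the extra timing fields in "usage".
	req.Header.Set("LocalAI-Extra-Usage", "true")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Every endpoint now tags its response with the machine tag.
	fmt.Println("served by:", resp.Header.Get("LocalAI-Machine-Tag"))

	var out struct {
		Usage struct {
			PromptTokens           int     `json:"prompt_tokens"`
			CompletionTokens       int     `json:"completion_tokens"`
			TotalTokens            int     `json:"total_tokens"`
			TimingPromptProcessing float64 `json:"timing_prompt_processing"`
			TimingTokenGeneration  float64 `json:"timing_token_generation"`
		} `json:"usage"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("usage: %+v\n", out.Usage)
}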

View file

@ -159,6 +159,10 @@ message Reply {
bytes message = 1;
int32 tokens = 2;
int32 prompt_tokens = 3;
int32 timing_prompt_tokens = 4;
int32 timing_predicted_tokens = 5;
double timing_prompt_processing = 6;
double timing_token_generation = 7;
}
message ModelOptions {

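The four new Reply fields mirror llama.cpp's "timings" block: prompt_n and predicted_n are token counts, while prompt_ms and predicted_ms are durations in milliseconds. As a sketch of what that enables downstream, assuming the standard protoc-gen-go getters on the generated Reply type (the helper itself is hypothetical):

// throughput derives tokens-per-second from the new timing fields;
// pb is assumed to be the generated package for this proto file.
func throughput(reply *pb.Reply) (promptTPS, genTPS float64) {
	if ms := reply.GetTimingPromptProcessing(); ms > 0 {
		promptTPS = float64(reply.GetTimingPromptTokens()) / (ms / 1000.0)
	}
	if ms := reply.GetTimingTokenGeneration(); ms > 0 {
		genTPS = float64(reply.GetTimingPredictedTokens()) / (ms / 1000.0)
	}
	return promptTPS, genTPS
}
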
View file

@ -2414,6 +2414,15 @@ public:
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
int32_t timing_prompt_tokens = result.result_json.value("timings", json{}).value("prompt_n", 0);
reply.set_timing_prompt_tokens(timing_prompt_tokens);
int32_t timing_predicted_tokens = result.result_json.value("timings", json{}).value("predicted_n", 0);
reply.set_timing_predicted_tokens(timing_predicted_tokens);
double timing_prompt_processing = result.result_json.value("timings", json{}).value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = result.result_json.value("timings", json{}).value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
// Log Request Correlation Id
LOG_VERBOSE("correlation:", {
{ "id", data["correlation_id"] }
@ -2454,6 +2463,15 @@ public:
reply->set_prompt_tokens(tokens_evaluated);
reply->set_tokens(tokens_predicted);
reply->set_message(completion_text);
int32_t timing_prompt_tokens = result.result_json.value("timings", json{}).value("prompt_n", 0);
reply->set_timing_prompt_tokens(timing_prompt_tokens);
int32_t timing_predicted_tokens = result.result_json.value("timings", json{}).value("predicted_n", 0);
reply->set_timing_predicted_tokens(timing_predicted_tokens);
double timing_prompt_processing = result.result_json.value("timings", json{}).value("prompt_ms", 0.0);
reply->set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = result.result_json.value("timings", json{}).value("predicted_ms", 0.0);
reply->set_timing_token_generation(timing_token_generation);
}
else
{

View file

@ -29,6 +29,10 @@ type LLMResponse struct {
type TokenUsage struct {
Prompt int
Completion int
TimingPromptTokens int
TimingPredictedTokens int
TimingPromptProcessing float64
TimingTokenGeneration float64
}
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
@ -123,6 +127,10 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
tokenUsage.Prompt = int(reply.PromptTokens)
tokenUsage.Completion = int(reply.Tokens)
tokenUsage.TimingPredictedTokens = int(reply.TimingPredictedTokens)
tokenUsage.TimingPromptTokens = int(reply.TimingPromptTokens)
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
for len(partialRune) > 0 {
r, size := utf8.DecodeRune(partialRune)
@ -157,6 +165,12 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
if tokenUsage.Completion == 0 {
tokenUsage.Completion = int(reply.Tokens)
}
tokenUsage.TimingPredictedTokens = int(reply.TimingPredictedTokens)
tokenUsage.TimingPromptTokens = int(reply.TimingPromptTokens)
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,

View file

@ -71,6 +71,7 @@ type RunCMD struct {
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
MachineTag string `env:"LOCALAI_MACHINE_TAG" help:"If not empty, add a LocalAI-Machine-Tag header with this value to each response"`
}
func (r *RunCMD) Run(ctx *cliContext.Context) error {
@ -107,6 +108,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
config.WithLoadToMemory(r.LoadToMemory),
config.WithMachineTag(r.MachineTag),
}
if r.DisableMetricsEndpoint {

View file

@ -4,6 +4,7 @@ import (
"context"
"embed"
"encoding/json"
"os"
"regexp"
"time"
@ -65,6 +66,8 @@ type ApplicationConfig struct {
ModelsURL []string
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
MachineTag string
}
type AppOption func(*ApplicationConfig)
@ -94,6 +97,16 @@ func WithModelPath(path string) AppOption {
}
}
func WithMachineTag(tag string) AppOption {
return func(o *ApplicationConfig) {
if tag == "" {
hostname, _ := os.Hostname()
tag = hostname
}
o.MachineTag = tag
}
}
func WithCors(b bool) AppOption {
return func(o *ApplicationConfig) {
o.CORS = b

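WithMachineTag falls back to the host name when the tag is empty (the os.Hostname error is deliberately ignored, leaving the tag blank only if the lookup fails), so an unset LOCALAI_MACHINE_TAG still yields a meaningful header. A minimal usage sketch with illustrative values:

cfg := &config.ApplicationConfig{}

config.WithMachineTag("")(cfg)
fmt.Println(cfg.MachineTag) // e.g. "workstation-01", taken from os.Hostname()

config.WithMachineTag("gpu-node-1")(cfg)
fmt.Println(cfg.MachineTag) // "gpu-node-1"
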
View file

@ -19,6 +19,8 @@ import (
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
req := new(schema.JINARerankRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{

View file

@ -24,6 +24,7 @@ import (
// @Router /tts [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
input := new(schema.TTSRequest)

View file

@ -19,6 +19,7 @@ import (
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
input := new(schema.VADRequest)

View file

@ -76,6 +76,7 @@ type AssistantRequest struct {
// @Router /v1/assistants [post]
func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
request := new(AssistantRequest)
if err := c.BodyParser(request); err != nil {
log.Warn().AnErr("Unable to parse AssistantRequest", err)
@ -137,6 +138,7 @@ func generateRandomID() int64 {
// @Router /v1/assistants [get]
func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
// Because we're altering the existing assistants list we should just duplicate it for now.
returnAssistants := Assistants
// Parse query parameters
@ -246,6 +248,7 @@ func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelNam
// @Router /v1/assistants/{assistant_id} [delete]
func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
assistantID := c.Params("assistant_id")
if assistantID == "" {
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
@ -278,6 +281,7 @@ func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
// @Router /v1/assistants/{assistant_id} [get]
func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
assistantID := c.Params("assistant_id")
if assistantID == "" {
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
@ -307,6 +311,7 @@ var (
func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
request := new(schema.AssistantFileRequest)
if err := c.BodyParser(request); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
@ -353,6 +358,7 @@ func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelL
}
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
assistantID := c.Params("assistant_id")
if assistantID == "" {
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
@ -410,6 +416,7 @@ func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelL
func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
request := new(AssistantRequest)
if err := c.BodyParser(request); err != nil {
log.Warn().AnErr("Unable to parse AssistantRequest", err)
@ -449,6 +456,7 @@ func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
assistantID := c.Params("assistant_id")
fileId := c.Params("file_id")
if assistantID == "" {
@ -503,6 +511,7 @@ func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model
func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
assistantID := c.Params("assistant_id")
fileId := c.Params("file_id")
if assistantID == "" {

View file

@ -30,7 +30,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
var id, textContentToReturn string
var created int
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
@ -40,18 +40,26 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
}
responses <- initialMessage
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
}
if extraUsage {
usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
}
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
Usage: usage,
}
responses <- resp
@ -59,7 +67,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
})
close(responses)
}
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
result := ""
_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
@ -90,6 +98,17 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
log.Error().Err(err).Msg("error handling question")
return
}
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
}
if extraUsage {
usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
}
resp := schema.OpenAIResponse{
ID: id,
@ -97,11 +116,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
Usage: usage,
}
responses <- resp
@ -160,6 +175,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
}
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", startupOptions.MachineTag)
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
@ -170,6 +187,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
}
c.Set("X-Correlation-ID", correlationID)
// Opt-in extra usage flag
extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
@ -311,6 +331,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("LocalAI-Machine-Tag", startupOptions.MachineTag)
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
@ -319,9 +340,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
responses := make(chan schema.OpenAIResponse)
if !shouldUseFn {
go process(predInput, input, config, ml, responses)
go process(predInput, input, config, ml, responses, extraUsage)
} else {
go processTools(noActionName, predInput, input, config, ml, responses)
go processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@ -449,6 +470,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
if err != nil {
return err
}
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
}
if extraUsage {
usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
}
log.Debug().Msgf("Total token usage: %v", tokenUsage)
resp := &schema.OpenAIResponse{
ID: id,
@ -456,11 +489,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
Usage: usage,
}
respData, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", respData)

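Note that the opt-in check above, c.Get("LocalAI-Extra-Usage", "") != "", treats any non-empty header value as true; even "false" or "0" would enable the extra fields. The same predicate outside fiber, as a hypothetical helper for illustration:

// extraUsageEnabled mirrors the fiber check above using net/http:
// any non-empty LocalAI-Extra-Usage request header opts in.
func extraUsageEnabled(r *http.Request) bool {
	return r.Header.Get("LocalAI-Extra-Usage") != ""
}
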
View file

@ -30,8 +30,19 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
}
if extraUsage {
usage.TimingPredictedTokens = tokenUsage.TimingPredictedTokens
usage.TimingPromptTokens = tokenUsage.TimingPromptTokens
usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing
}
resp := schema.OpenAIResponse{
ID: id,
Created: created,
@ -43,11 +54,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
},
},
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
Usage: usage,
}
log.Debug().Msgf("Sending goroutine: %s", s)
@ -58,8 +65,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
}
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
// Add Correlation
c.Set("X-Correlation-ID", id)
// Opt-in extra usage flag
extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
@ -113,7 +125,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, ml, responses)
go process(predInput, input, config, ml, responses, extraUsage)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@ -170,11 +182,24 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
totalTokenUsage.TimingPredictedTokens += tokenUsage.TimingPredictedTokens
totalTokenUsage.TimingPromptTokens += tokenUsage.TimingPromptTokens
totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration
totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing
result = append(result, r...)
}
usage := schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
}
if extraUsage {
usage.TimingPredictedTokens = totalTokenUsage.TimingPredictedTokens
usage.TimingPromptTokens = totalTokenUsage.TimingPromptTokens
usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing
}
resp := &schema.OpenAIResponse{
ID: id,
@ -182,11 +207,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
Usage: usage,
}
jsonResult, _ := json.Marshal(resp)

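Because the completion endpoint accumulates per-choice usage with +=, the timing fields in the final usage object are sums over all generated choices. A small worked example with made-up values:

// Accumulation as in the endpoint above, for two choices.
total := backend.TokenUsage{}
for _, u := range []backend.TokenUsage{
	{Prompt: 10, Completion: 20, TimingTokenGeneration: 400.0},
	{Prompt: 10, Completion: 30, TimingTokenGeneration: 600.0},
} {
	total.Prompt += u.Prompt
	total.Completion += u.Completion
	total.TimingTokenGeneration += u.TimingTokenGeneration
}
// total.Prompt == 20, total.Completion == 50, total.TimingTokenGeneration == 1000 (ms)
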
View file

@ -25,6 +25,11 @@ import (
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
// Opt-in extra usage flag
extraUsage := c.Get("LocalAI-Extra-Usage", "") != ""
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
@ -61,8 +66,24 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
totalTokenUsage.TimingPredictedTokens += tokenUsage.TimingPredictedTokens
totalTokenUsage.TimingPromptTokens += tokenUsage.TimingPromptTokens
totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration
totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing
result = append(result, r...)
}
usage := schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
}
if extraUsage {
usage.TimingPredictedTokens = totalTokenUsage.TimingPredictedTokens
usage.TimingPromptTokens = totalTokenUsage.TimingPromptTokens
usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration
usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing
}
id := uuid.New().String()
created := int(time.Now().Unix())
@ -72,11 +93,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "edit",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
Usage: usage,
}
jsonResult, _ := json.Marshal(resp)

View file

@ -23,6 +23,7 @@ import (
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
model, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)

View file

@ -23,6 +23,7 @@ const UploadedFilesFile = "uploadedFiles.json"
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
file, err := c.FormFile("file")
if err != nil {
return err
@ -82,6 +83,7 @@ func getNextFileId() int64 {
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
var listFiles schema.ListFiles
purpose := c.Query("purpose")
@ -120,6 +122,7 @@ func getFileFromRequest(c *fiber.Ctx) (*schema.File, error) {
// @Router /v1/files/{file_id} [get]
func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
@ -142,6 +145,7 @@ type DeleteStatus struct {
func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
@ -179,6 +183,7 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
// GetFilesContentsEndpoint
func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))

View file

@ -66,6 +66,7 @@ func downloadFile(url string) (string, error) {
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)

View file

@ -12,8 +12,10 @@ import (
// @Summary List and describe the various models available in the API.
// @Success 200 {object} schema.ModelsDataResponse "Response"
// @Router /v1/models [get]
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
// If blank, no filter is applied.
filter := c.Query("filter")

View file

@ -25,6 +25,7 @@ import (
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Set("LocalAI-Machine-Tag", appConfig.MachineTag)
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)

View file

@ -130,6 +130,6 @@ func RegisterOpenAIRoutes(app *fiber.App,
}
// List models
app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
}

View file

@ -23,6 +23,11 @@ type OpenAIUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
// Extra timing data, disabled by default as it's not part of the OpenAI specification
TimingPromptTokens int `json:"timing_prompt_tokens,omitempty"`
TimingPredictedTokens int `json:"timing_predicted_tokens,omitempty"`
TimingPromptProcessing float64 `json:"timing_prompt_processing,omitempty"`
TimingTokenGeneration float64 `json:"timing_token_generation,omitempty"`
}
type Item struct {
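
Because the timing fields are tagged omitempty, they disappear from the JSON when left at their zero values, so responses stay OpenAI-shaped unless the client opts in. A quick sketch:

// With extra usage disabled the timing fields stay zero and omitempty
// drops them from the marshalled object.
u := schema.OpenAIUsage{PromptTokens: 12, CompletionTokens: 34, TotalTokens: 46}
b, _ := json.Marshal(u)
fmt.Println(string(b))
// prints: {"prompt_tokens":12,"completion_tokens":34,"total_tokens":46}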