Mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-20 10:35:01 +00:00)
feat(grpc): return consumed token count and update response accordingly (#2035)

Fixes: #1920

parent de3a1a0a8e
commit e843d7df0e

4 changed files with 20 additions and 4 deletions
@@ -114,6 +114,8 @@ message PredictOptions {
 // The response message containing the result
 message Reply {
   bytes message = 1;
+  int32 tokens = 2;
+  int32 prompt_tokens = 3;
 }
 
 message ModelOptions {
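The two added fields are what the Go service layer reads later in this commit as reply.Tokens and reply.PromptTokens. For reference, a minimal sketch of the shape the regenerated Go bindings are expected to take, following protoc-gen-go's usual snake_case-to-CamelCase mapping; this is illustrative, not copied from the generated file:

```go
// Illustrative only; the real type is produced by protoc-gen-go from the
// .proto definition above.
package proto

// Reply mirrors the protobuf message: message carries the (partial or final)
// completion text, tokens the completion tokens the backend generated, and
// prompt_tokens the tokens it evaluated for the prompt.
type Reply struct {
	Message      []byte // bytes message = 1
	Tokens       int32  // int32 tokens = 2
	PromptTokens int32  // int32 prompt_tokens = 3
}
```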
@@ -2332,6 +2332,10 @@ public:
         std::string completion_text = result.result_json.value("content", "");
 
         reply.set_message(completion_text);
+        int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
+        reply.set_tokens(tokens_predicted);
+        int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
+        reply.set_prompt_tokens(tokens_evaluated);
 
         // Send the reply
         writer->Write(reply);
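Because the streaming path above now sets both counters on every chunk it writes, a client can track them while draining the stream. The sketch below is a generic server-streaming consumer: Reply, replyStream, and fakeStream are local stand-ins defined here for illustration, not LocalAI's generated gRPC types or client API.

```go
package main

import (
	"fmt"
	"io"
)

// Local stand-in for the gRPC Reply message (illustration only).
type Reply struct {
	Message      []byte
	Tokens       int32
	PromptTokens int32
}

// replyStream mirrors the usual gRPC server-streaming receiver pattern.
type replyStream interface {
	Recv() (*Reply, error)
}

// consume drains the stream, concatenating the text and keeping the most
// recent non-zero token counters (intermediate chunks may report zero).
func consume(stream replyStream) (text string, promptTokens, completionTokens int32, err error) {
	for {
		chunk, rerr := stream.Recv()
		if rerr == io.EOF {
			return text, promptTokens, completionTokens, nil
		}
		if rerr != nil {
			return "", 0, 0, rerr
		}
		text += string(chunk.Message)
		if chunk.PromptTokens > 0 {
			promptTokens = chunk.PromptTokens
		}
		if chunk.Tokens > 0 {
			completionTokens = chunk.Tokens
		}
	}
}

// fakeStream feeds canned chunks so the example runs without a server.
type fakeStream struct{ chunks []*Reply }

func (f *fakeStream) Recv() (*Reply, error) {
	if len(f.chunks) == 0 {
		return nil, io.EOF
	}
	c := f.chunks[0]
	f.chunks = f.chunks[1:]
	return c, nil
}

func main() {
	stream := &fakeStream{chunks: []*Reply{
		{Message: []byte("Hel")},
		{Message: []byte("lo"), Tokens: 2, PromptTokens: 34},
	}}
	text, prompt, completion, _ := consume(stream)
	fmt.Printf("%q prompt=%d completion=%d\n", text, prompt, completion) // "Hello" prompt=34 completion=2
}
```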
@@ -2357,6 +2361,10 @@ public:
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
             completion_text = result.result_json.value("content", "");
+            int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
+            int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
+            reply->set_prompt_tokens(tokens_evaluated);
+            reply->set_tokens(tokens_predicted);
             reply->set_message(completion_text);
         }
         else
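Both server paths read the same two fields out of the llama.cpp result JSON, defaulting to 0 when they are missing: tokens_predicted becomes tokens (completion) and tokens_evaluated becomes prompt_tokens. A small Go sketch of that mapping; the sample payload is made up, only the field names come from the code above.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// llamaResult models just the fields used above; the payload is illustrative.
type llamaResult struct {
	Content         string `json:"content"`
	TokensPredicted int32  `json:"tokens_predicted"`
	TokensEvaluated int32  `json:"tokens_evaluated"`
}

func main() {
	raw := `{"content":"Hello!","tokens_predicted":12,"tokens_evaluated":34}`

	var res llamaResult
	if err := json.Unmarshal([]byte(raw), &res); err != nil {
		panic(err)
	}

	// Same assignment as reply.set_tokens(...) / reply.set_prompt_tokens(...):
	// predicted -> tokens, evaluated -> prompt_tokens.
	fmt.Printf("tokens=%d prompt_tokens=%d\n", res.TokensPredicted, res.TokensEvaluated)
}
```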
@@ -189,6 +189,12 @@ func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest,
 	} else {
 		go func() {
 			reply, err := inferenceModel.Predict(ctx, grpcPredOpts)
+			if tokenUsage.Prompt == 0 {
+				tokenUsage.Prompt = int(reply.PromptTokens)
+			}
+			if tokenUsage.Completion == 0 {
+				tokenUsage.Completion = int(reply.Tokens)
+			}
 			if err != nil {
 				rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
 				close(rawResultChannel)
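The intent of the Go-side change is a fallback rule: counts already computed locally keep priority, and the counters reported by the backend over gRPC are only used when the local value is still zero. A standalone sketch of that precedence, where TokenUsage is a local stand-in for LocalAI's type:

```go
package main

import "fmt"

// TokenUsage is a local stand-in for the service's usage bookkeeping.
type TokenUsage struct {
	Prompt     int
	Completion int
}

// mergeUsage fills any zero field from the backend-reported gRPC counters,
// mirroring the two `if ... == 0` checks in the hunk above.
func mergeUsage(local TokenUsage, replyPromptTokens, replyTokens int32) TokenUsage {
	if local.Prompt == 0 {
		local.Prompt = int(replyPromptTokens)
	}
	if local.Completion == 0 {
		local.Completion = int(replyTokens)
	}
	return local
}

func main() {
	fmt.Println(mergeUsage(TokenUsage{}, 34, 12))           // {34 12}: backend counts win when nothing was computed locally
	fmt.Println(mergeUsage(TokenUsage{Prompt: 40}, 34, 12)) // {40 12}: a locally computed prompt count is kept
}
```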
@@ -160,7 +160,7 @@ func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest
 
 	bc, request, err := oais.getConfig(request)
 	if err != nil {
-		log.Error().Msgf("[oais::GenerateTextFromRequest] error getting configuration: %q", err)
+		log.Error().Err(err).Msgf("[oais::GenerateTextFromRequest] error getting configuration")
 		return
 	}
 
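This hunk (and the similar one further down) replaces string interpolation of the error with zerolog's Err(), which attaches the error as a structured field and keeps the message text constant and easier to filter on. A minimal standalone comparison using zerolog directly; the logger setup below is illustrative:

```go
package main

import (
	"errors"
	"os"

	"github.com/rs/zerolog"
)

func main() {
	logger := zerolog.New(os.Stderr).With().Timestamp().Logger()
	err := errors.New("config file not found")

	// Before: the error is baked into the message string.
	logger.Error().Msgf("error getting configuration: %q", err)

	// After: the error travels as its own "error" field in the JSON output.
	logger.Error().Err(err).Msg("error getting configuration")
}
```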
@@ -259,7 +259,7 @@ func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest
 	// If any of the setup goroutines experienced an error, quit early here.
 	if setupError != nil {
 		go func() {
-			log.Error().Msgf("[OAIS GenerateTextFromRequest] caught an error during setup: %q", setupError)
+			log.Error().Err(setupError).Msgf("[OAIS GenerateTextFromRequest] caught an error during setup")
 			rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: setupError}
 			close(rawFinalResultChannel)
 		}()
@@ -603,7 +603,7 @@ func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *sche
 				Usage: schema.OpenAIUsage{
 					PromptTokens:     rawResult.Value.Usage.Prompt,
 					CompletionTokens: rawResult.Value.Usage.Completion,
-					TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
+					TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion,
 				},
 			}
 
@@ -644,7 +644,7 @@ func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *sche
 				Usage: schema.OpenAIUsage{
 					PromptTokens:     rawResult.Value.Usage.Prompt,
 					CompletionTokens: rawResult.Value.Usage.Completion,
-					TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
+					TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion,
 				},
 			}
 
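The last two hunks fix the same copy-paste bug in the usage accounting: the total was computed as prompt + prompt instead of prompt + completion. A tiny worked example with made-up numbers, using a local stand-in for schema.OpenAIUsage:

```go
package main

import "fmt"

// OpenAIUsage is a local stand-in for the schema type used above.
type OpenAIUsage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

func main() {
	prompt, completion := 34, 12

	buggy := OpenAIUsage{PromptTokens: prompt, CompletionTokens: completion, TotalTokens: prompt + prompt}
	fixed := OpenAIUsage{PromptTokens: prompt, CompletionTokens: completion, TotalTokens: prompt + completion}

	fmt.Println(buggy.TotalTokens) // 68: double-counts the prompt
	fmt.Println(fixed.TotalTokens) // 46: what an OpenAI-style usage object should report
}
```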