wire to grpc

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

parent 01e2e3dbc3
commit 8fea82e68b

1 changed file with 58 additions and 7 deletions
@@ -536,6 +536,12 @@ struct llama_server_context
             return false;
         }
 
+        // Enable reranking if embeddings are enabled - moved after context initialization
+        if (params.embedding) {
+            params.reranking = true;
+            LOG_INFO("Reranking enabled (embeddings are enabled)", {});
+        }
+
         if (multimodal) {
             const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
             const int n_embd_llm = llama_model_n_embd(model);
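Background, not part of the diff: in upstream llama.cpp, reranking is served through the embeddings path using RANK pooling, which is why enabling embeddings is the trigger chosen here for enabling reranking. A minimal sketch of a startup sanity check, assuming the standard llama.h API (llama_pooling_type and LLAMA_POOLING_TYPE_RANK):

    #include "llama.h"

    // Sketch only: verify the loaded model can actually rerank.
    // Rerank models expose a RANK pooling head that yields one relevance
    // score per sequence instead of an embedding vector.
    static bool context_supports_rerank(const struct llama_context * ctx) {
        return llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_RANK;
    }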
@@ -1424,11 +1430,16 @@ struct llama_server_context
 
             float score = -1e6f; // Default score if we fail to get embeddings
 
-            if (!params.rerank)
+            if (!params.reranking)
             {
                 LOG_WARNING("reranking disabled", {
-                    {"params.rerank", params.rerank},
-                });
+                    {"params.reranking", params.reranking},
+                });
             }
+            else if (ctx == nullptr)
+            {
+                LOG_ERR("context is null, cannot perform reranking");
+                res.error = true;
+            }
             else
             {
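The new `ctx == nullptr` branch fails soft: it marks the task result as an error instead of dereferencing a null context. Downstream, every handler in this file checks that flag on the received result; the recurring consumer-side pattern looks like this (a sketch assembled from the call sites later in this diff):

    task_result result = llama.queue_results.recv(task_id);
    if (!result.error && result.stop) {
        // success: consume result.result_json
    } else {
        // failure (e.g. null context during reranking): report an error
    }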
@@ -1455,7 +1466,7 @@ struct llama_server_context
             res.result_json = json
             {
                 {"score", score},
-                {"tokens", slot.n_prompt_tokens}
+                {"tokens", slot.num_prompt_tokens}
             };
 
             queue_results.send(res);
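The rerank task result is a small JSON payload with the shape built above; a sketch of consuming it with nlohmann::json (the `json` alias this file already uses), mirroring how the gRPC handler later in this diff reads it:

    // Example payload values; real ones come from queue_results.recv().
    json res_json = { {"score", 0.42f}, {"tokens", 17} };
    float score  = res_json.value("score", 0.0f);  // relevance score
    int   tokens = res_json.value("tokens", 0);    // prompt tokens consumed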
@@ -2547,7 +2558,7 @@ public:
        json data = parse_options(true, request, llama);
        const int task_id = llama.queue_tasks.get_new_id();
        llama.queue_results.add_waiting_task_id(task_id);
-       llama.request_completion(task_id, data, false, false, -1);
+       llama.request_completion(task_id, data, false, false, false, -1);
        while (true)
        {
            task_result result = llama.queue_results.recv(task_id);
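The extra `false` added here and at the following call sites corresponds to a new boolean parameter on `request_completion`. Its declaration is outside this diff; judging from the call sites (including the Rerank handler below, which passes `true`), the updated signature presumably looks like this sketch:

    // Presumed shape after this commit; the new 5th argument selects rerank mode.
    void request_completion(int task_id, json data,
                            bool infill,     // existing
                            bool embedding,  // existing
                            bool rerank,     // new: route the task to the reranking path
                            int multitask_id);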
@@ -2601,7 +2612,7 @@ public:
        json data = parse_options(false, request, llama);
        const int task_id = llama.queue_tasks.get_new_id();
        llama.queue_results.add_waiting_task_id(task_id);
-       llama.request_completion(task_id, data, false, false, -1);
+       llama.request_completion(task_id, data, false, false, false, -1);
        std::string completion_text;
        task_result result = llama.queue_results.recv(task_id);
        if (!result.error && result.stop) {
@@ -2638,7 +2649,7 @@ public:
        json data = parse_options(false, request, llama);
        const int task_id = llama.queue_tasks.get_new_id();
        llama.queue_results.add_waiting_task_id(task_id);
-       llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, -1);
+       llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, false, -1);
        // get the result
        task_result result = llama.queue_results.recv(task_id);
        //std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
@@ -2670,6 +2681,46 @@ public:
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        // Create a JSON object with the query and documents
+        json data = {
+            {"prompt", request->query()},
+            {"documents", request->documents()},
+            {"top_n", request->top_n()}
+        };
+
+        // Generate a new task ID
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+
+        // Queue the task with reranking mode enabled
+        llama.request_completion(task_id, data, false, false, true, -1);
+
+        // Get the result
+        task_result result = llama.queue_results.recv(task_id);
+        llama.queue_results.remove_waiting_task_id(task_id);
+
+        if (!result.error && result.stop) {
+            // Set usage information
+            backend::Usage* usage = rerankResult->mutable_usage();
+            usage->set_total_tokens(result.result_json.value("tokens", 0));
+            usage->set_prompt_tokens(result.result_json.value("tokens", 0));
+
+            // Get the score from the result
+            float score = result.result_json.value("score", 0.0f);
+
+            // Create document results for each input document
+            for (int i = 0; i < request->documents_size(); i++) {
+                backend::DocumentResult* doc_result = rerankResult->add_results();
+                doc_result->set_index(i);
+                doc_result->set_text(request->documents(i));
+                doc_result->set_relevance_score(score);
+            }
+        }
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
         llama_client_slot* active_slot = llama.get_active_slot();
 
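For illustration, a sketch of how a gRPC client could exercise the new Rerank RPC. The field accessors follow the handler above (query, documents, top_n; per-document index/text/relevance_score plus usage); the service name backend.Backend, the generated header, and the address are assumptions:

    #include <cstdio>
    #include <grpcpp/grpcpp.h>
    #include "backend.grpc.pb.h"  // assumed: generated from LocalAI's backend.proto

    int main() {
        auto channel = grpc::CreateChannel("localhost:50051",  // hypothetical address
                                           grpc::InsecureChannelCredentials());
        auto stub = backend::Backend::NewStub(channel);        // service name assumed

        backend::RerankRequest req;
        req.set_query("What is quantization?");
        req.add_documents("Quantization reduces numeric precision to save memory.");
        req.add_documents("Tokyo is the capital of Japan.");
        req.set_top_n(2);

        backend::RerankResult res;
        grpc::ClientContext ctx;
        if (stub->Rerank(&ctx, req, &res).ok()) {
            for (const auto & doc : res.results())
                printf("[%d] score=%.3f  %s\n",
                       doc.index(), doc.relevance_score(), doc.text().c_str());
        }
        return 0;
    }

Note that, as written, the handler assigns the same score to every input document; per-document scores would require queueing one task per query/document pair.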