commit dff75bb64f
Ettore Di Giacinto, 2025-05-07 16:53:00 +00:00 (committed by GitHub)
2 changed files with 115 additions and 5 deletions

File 1 of 2:

@@ -217,6 +217,7 @@ struct llama_client_slot
     bool infill = false;
     bool embedding = false;
+    bool reranker = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -535,6 +536,12 @@ struct llama_server_context
             return false;
         }
 
+        // Enable reranking if embeddings are enabled - moved after context initialization
+        if (params.embedding) {
+            params.reranking = true;
+            LOG_INFO("Reranking enabled (embeddings are enabled)", {});
+        }
+
         if (multimodal) {
             const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
             const int n_embd_llm = llama_model_n_embd(model);
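
Note the coupling this hunk introduces: reranking is switched on whenever embeddings are, so any model loaded with embeddings enabled will also accept rerank tasks.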
@@ -1413,7 +1420,59 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void send_rerank(llama_client_slot &slot, const llama_batch & batch)
+    {
+        task_result res;
+        res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
+        res.error = false;
+        res.stop = true;
+
+        float score = -1e6f; // Default score if we fail to get embeddings
+
+        if (!params.reranking)
+        {
+            LOG_WARNING("reranking disabled", {
+                {"params.reranking", params.reranking},
+            });
+        }
+        else if (ctx == nullptr)
+        {
+            LOG_ERR("context is null, cannot perform reranking");
+            res.error = true;
+        }
+        else
+        {
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+                    continue;
+                }
+
+                score = embd[0];
+            }
+        }
+
+        // Format result as JSON similar to the embedding function
+        res.result_json = json
+        {
+            {"score", score},
+            {"tokens", slot.num_prompt_tokens}
+        };
+
+        queue_results.send(res);
+    }
+
+    void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
     {
         task_server task;
         task.id = task_id;
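
A note on the scoring logic above: in llama.cpp, reranker models are typically run with rank pooling (LLAMA_POOLING_TYPE_RANK), where the pooled per-sequence "embedding" is a single relevance logit, which is why send_rerank reads embd[0]. Below is a minimal standalone sketch of that lookup order; the helper name rerank_score is hypothetical, while llama_get_embeddings_seq and llama_get_embeddings_ith are the same llama.cpp C API calls the patch uses:

#include "llama.h"

// Hypothetical standalone helper mirroring send_rerank's lookup order:
// prefer the pooled per-sequence embedding, fall back to the per-token
// embedding at index i_fallback, and use the patch's sentinel on failure.
static float rerank_score(llama_context * ctx, llama_seq_id seq, int i_fallback) {
    const float * embd = llama_get_embeddings_seq(ctx, seq);
    if (embd == NULL) {
        embd = llama_get_embeddings_ith(ctx, i_fallback);
    }
    if (embd == NULL) {
        return -1e6f; // default score when no embeddings are available
    }
    // Under rank pooling the first component is the relevance logit.
    return embd[0];
}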
@@ -1421,6 +1480,7 @@ struct llama_server_context
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
+        task.reranking_mode = rerank;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
@@ -1552,7 +1612,7 @@ struct llama_server_context
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
         }
     }
@@ -1591,6 +1651,7 @@ struct llama_server_context
         slot->infill = task.infill_mode;
         slot->embedding = task.embedding_mode;
+        slot->reranker = task.reranking_mode;
         slot->task_id = task.id;
         slot->multitask_id = task.multitask_id;
@@ -2034,6 +2095,14 @@ struct llama_server_context
                 continue;
             }
 
+            if (slot.reranker)
+            {
+                send_rerank(slot, batch_view);
+                slot.release();
+                slot.i_batch = -1;
+                continue;
+            }
+
             completion_token_output result;
             const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
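
This hunk is the dispatch point: a slot flagged as a reranker never reaches token sampling; once its batch view has been decoded, send_rerank emits the score and the slot is released.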
@@ -2489,7 +2558,7 @@ public:
         json data = parse_options(true, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         while (true)
         {
             task_result result = llama.queue_results.recv(task_id);
@@ -2543,7 +2612,7 @@ public:
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         std::string completion_text;
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
@@ -2580,7 +2649,7 @@ public:
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, -1);
+        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, false, -1);
         // get the result
         task_result result = llama.queue_results.recv(task_id);
         //std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
@@ -2612,6 +2681,46 @@ public:
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        // Create a JSON object with the query and documents
+        json data = {
+            {"prompt", request->query()},
+            {"documents", request->documents()},
+            {"top_n", request->top_n()}
+        };
+
+        // Generate a new task ID
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+
+        // Queue the task with reranking mode enabled
+        llama.request_completion(task_id, data, false, false, true, -1);
+
+        // Get the result
+        task_result result = llama.queue_results.recv(task_id);
+        llama.queue_results.remove_waiting_task_id(task_id);
+
+        if (!result.error && result.stop) {
+            // Set usage information
+            backend::Usage* usage = rerankResult->mutable_usage();
+            usage->set_total_tokens(result.result_json.value("tokens", 0));
+            usage->set_prompt_tokens(result.result_json.value("tokens", 0));
+
+            // Get the score from the result
+            float score = result.result_json.value("score", 0.0f);
+
+            // Create document results for each input document
+            for (int i = 0; i < request->documents_size(); i++) {
+                backend::DocumentResult* doc_result = rerankResult->add_results();
+                doc_result->set_index(i);
+                doc_result->set_text(request->documents(i));
+                doc_result->set_relevance_score(score);
+            }
+        }
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
         llama_client_slot* active_slot = llama.get_active_slot();
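
For reference, a hedged sketch of how a client might call the new RPC. It assumes the backend.Backend gRPC service from LocalAI's backend.proto and a locally generated stub (the address and stub names are illustrative); the RerankRequest/RerankResult fields match the ones the handler above reads and writes:

#include <grpcpp/grpcpp.h>
#include <iostream>
#include "backend.grpc.pb.h" // generated from backend.proto (assumed available)

int main() {
    // Assumed endpoint; the backend's listen address is configurable.
    auto channel = grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
    auto stub = backend::Backend::NewStub(channel);

    backend::RerankRequest req;
    req.set_query("What is the capital of France?");
    req.add_documents("Paris is the capital of France.");
    req.add_documents("Berlin is the capital of Germany.");
    req.set_top_n(2);

    backend::RerankResult res;
    grpc::ClientContext ctx;
    if (stub->Rerank(&ctx, req, &res).ok()) {
        for (const auto & doc : res.results()) {
            std::cout << doc.index() << ": " << doc.relevance_score() << "\n";
        }
    }
    return 0;
}

Note that send_rerank produces a single score per task, so as written every DocumentResult receives the same relevance_score; distinct per-document scores would require queueing one task per query/document pair.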

File 2 of 2:

@@ -61,6 +61,7 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    bool reranking_mode = false;
     int multitask_id = -1;
 };
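
End to end: Rerank() queues a completion task with reranking_mode set, task_server carries the flag to the slot as slot->reranker, and when the slot's batch is decoded send_rerank() reads the score from the embeddings and returns it over the results queue.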