diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index 790eebbc..22b8576e 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -536,6 +536,12 @@ struct llama_server_context
             return false;
         }
 
+        // Enable reranking if embeddings are enabled - moved after context initialization
+        if (params.embedding) {
+            params.reranking = true;
+            LOG_INFO("Reranking enabled (embeddings are enabled)", {});
+        }
+
         if (multimodal) {
             const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
             const int n_embd_llm = llama_model_n_embd(model);
@@ -1424,11 +1430,16 @@ struct llama_server_context
 
                 float score = -1e6f; // Default score if we fail to get embeddings
 
-                if (!params.rerank)
+                if (!params.reranking)
                 {
                     LOG_WARNING("reranking disabled", {
-                        {"params.rerank", params.rerank},
-                    });
+                        {"params.reranking", params.reranking},
+                    });
+                }
+                else if (ctx == nullptr)
+                {
+                    LOG_ERR("context is null, cannot perform reranking");
+                    res.error = true;
                 }
                 else
                 {
@@ -1455,7 +1466,7 @@ struct llama_server_context
                 res.result_json = json
                 {
                     {"score", score},
-                    {"tokens", slot.n_prompt_tokens}
+                    {"tokens", slot.num_prompt_tokens}
                 };
 
                 queue_results.send(res);
@@ -2547,7 +2558,7 @@ public:
         json data = parse_options(true, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         while (true)
         {
             task_result result = llama.queue_results.recv(task_id);
@@ -2601,7 +2612,7 @@ public:
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         std::string completion_text;
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
@@ -2638,7 +2649,7 @@ public:
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, -1);
+        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, false, -1);
         // get the result
         task_result result = llama.queue_results.recv(task_id);
         //std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
@@ -2670,6 +2681,46 @@ public:
 
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        // Create a JSON object with the query and documents
+        json data = {
+            {"prompt", request->query()},
+            {"documents", request->documents()},
+            {"top_n", request->top_n()}
+        };
+
+        // Generate a new task ID
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+
+        // Queue the task with reranking mode enabled
+        llama.request_completion(task_id, data, false, false, true, -1);
+
+        // Get the result
+        task_result result = llama.queue_results.recv(task_id);
+        llama.queue_results.remove_waiting_task_id(task_id);
+
+        if (!result.error && result.stop) {
+            // Set usage information
+            backend::Usage* usage = rerankResult->mutable_usage();
+            usage->set_total_tokens(result.result_json.value("tokens", 0));
+            usage->set_prompt_tokens(result.result_json.value("tokens", 0));
+
+            // Get the score from the result
+            float score = result.result_json.value("score", 0.0f);
+
+            // Create document results for each input document
+            for (int i = 0; i < request->documents_size(); i++) {
+                backend::DocumentResult* doc_result = rerankResult->add_results();
+                doc_result->set_index(i);
+                doc_result->set_text(request->documents(i));
+                doc_result->set_relevance_score(score);
+            }
+        }
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
         llama_client_slot* active_slot = llama.get_active_slot();
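
For reference, a minimal C++ client sketch that exercises the new Rerank RPC is shown below. It is illustrative only and not part of the patch: it assumes the gRPC service defined in backend.proto is named Backend, that the generated stubs are available as backend.grpc.pb.h, and that the backend listens on localhost:50051; adjust these names and the address to match the actual build.

// Illustrative client for the new Rerank RPC (not part of this patch).
// Assumptions: service "Backend" in backend.proto, generated header
// backend.grpc.pb.h, and a backend listening on localhost:50051.
#include <iostream>
#include <memory>

#include <grpcpp/grpcpp.h>
#include "backend.grpc.pb.h"

int main() {
    auto channel = grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
    std::unique_ptr<backend::Backend::Stub> stub = backend::Backend::NewStub(channel);

    backend::RerankRequest request;
    request.set_query("What is the capital of France?");
    request.add_documents("Paris is the capital of France.");
    request.add_documents("Berlin is the capital of Germany.");
    request.set_top_n(2);

    backend::RerankResult response;
    grpc::ClientContext context;
    grpc::Status status = stub->Rerank(&context, request, &response);
    if (!status.ok()) {
        std::cerr << "Rerank failed: " << status.error_message() << std::endl;
        return 1;
    }

    // Each input document comes back with its index, text, and relevance score.
    for (const auto& doc : response.results()) {
        std::cout << doc.index() << " (" << doc.relevance_score() << "): " << doc.text() << std::endl;
    }
    std::cout << "total tokens: " << response.usage().total_tokens() << std::endl;
    return 0;
}

Note that, as implemented in this patch, every document in the response carries the same score taken from the single task result, so the returned results mirror the input order rather than being re-sorted per document.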