mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat(llama.cpp): add reranking
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
41e239c67e
commit
5bf05cec1f
5 changed files with 105 additions and 0 deletions
|
@ -231,6 +231,7 @@ static void params_parse(const backend::ModelOptions* request,
|
|||
params.n_parallel = 1;
|
||||
}
|
||||
|
||||
|
||||
const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
|
||||
if (llama_grpc_servers != NULL) {
|
||||
add_rpc_devices(std::string(llama_grpc_servers));
|
||||
|
@ -291,6 +292,7 @@ static void params_parse(const backend::ModelOptions* request,
|
|||
params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)
|
||||
|
||||
params.embedding = request->embeddings();
|
||||
params.reranking = request->reranking();
|
||||
|
||||
if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
|
||||
else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
|
||||
|
@ -791,6 +793,93 @@ public:
|
|||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
// Rerank a list of documents against a query using a rerank-capable model.
//
// The server must have been started with reranking enabled and embeddings
// disabled (the two pooling modes are mutually exclusive). One rerank task is
// queued per document; each result carries an "index", "score" and
// "tokens_evaluated" field that is mapped back into the gRPC response.
//
// Returns:
//   UNIMPLEMENTED    - server not started in reranking mode
//   INVALID_ARGUMENT - missing query or empty document list
//   INTERNAL         - backend error while processing the tasks
//   OK               - rerankResult populated with per-document scores + usage
grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
    if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
        return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
    }

    // Validate request
    if (request->query().empty()) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided");
    }

    if (request->documents_size() == 0) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
    }

    // Tokenize the query once; it is paired with every document below.
    llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, request->query(), /* add_special */ false, true)[0];

    // Create and queue one rerank task per document.
    json responses = json::array();
    bool error = false;
    // Default message, overridden with the backend's own message when available
    // (previously the backend error detail was silently discarded).
    std::string error_message = "Error in receiving results";
    std::unordered_set<int> task_ids;
    {
        std::vector<server_task> tasks;
        std::vector<std::string> documents;
        documents.reserve(request->documents_size());
        for (int i = 0; i < request->documents_size(); i++) {
            documents.push_back(request->documents(i));
        }

        auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
        tasks.reserve(tokenized_docs.size());
        for (size_t i = 0; i < tokenized_docs.size(); i++) {
            // format_rerank builds the (query, document) prompt layout the
            // rerank model expects.
            auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
            server_task task = server_task(SERVER_TASK_TYPE_RERANK);
            task.id = ctx_server.queue_tasks.get_new_id();
            task.index = i; // used to map the result back onto request->documents(i)
            task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
            tasks.push_back(std::move(task));
        }

        task_ids = server_task::get_list_id(tasks);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(std::move(tasks));
    }

    // Block until all tasks have produced a result (or one errors out).
    ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
        for (auto & res : results) {
            GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
            responses.push_back(res->to_json());
        }
    }, [&](const json & error_data) {
        // Propagate the backend's error message instead of swallowing it.
        error = true;
        if (error_data.is_object() && error_data.contains("message")) {
            error_message = error_data.value("message", error_message);
        }
    }, [&]() {
        return false; // never cancel early
    });

    ctx_server.queue_results.remove_waiting_task_ids(task_ids);

    if (error) {
        return grpc::Status(grpc::StatusCode::INTERNAL, error_message);
    }

    // Set usage information
    backend::Usage* usage = rerankResult->mutable_usage();
    int total_tokens = 0;
    int prompt_tokens = 0;

    // Create document results
    for (const auto& response : responses) {
        const int idx = response.value("index", 0);
        // Guard the backend-supplied index: protobuf's repeated-field accessor
        // aborts the process on an out-of-range index, so fail gracefully here.
        if (idx < 0 || idx >= request->documents_size()) {
            return grpc::Status(grpc::StatusCode::INTERNAL, "Backend returned an out-of-range document index");
        }

        backend::DocumentResult* doc_result = rerankResult->add_results();
        doc_result->set_index(idx);
        doc_result->set_text(request->documents(idx));
        doc_result->set_relevance_score(response.value("score", 0.0f));

        // Add tokens evaluated for this document
        int tokens_evaluated = response.value("tokens_evaluated", 0);
        total_tokens += tokens_evaluated;
        prompt_tokens += tokens_evaluated;
    }

    // Set the total tokens in usage
    usage->set_total_tokens(total_tokens);
    usage->set_prompt_tokens(prompt_tokens);

    return grpc::Status::OK;
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
|
||||
json body = parse_options(false, request);
|
||||
body["stream"] = false;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue