diff --git a/backend/backend.proto b/backend/backend.proto
index cdf09bf2..9021a353 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -255,6 +255,8 @@ message ModelOptions {
   string CacheTypeValue = 64;
 
   repeated GrammarTrigger GrammarTriggers = 65;
+
+  bool Reranking = 71;
 }
 
 message Result {
diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index e6dc4b8f..be277bfa 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -231,6 +231,7 @@ static void params_parse(const backend::ModelOptions* request,
         params.n_parallel = 1;
     }
+
    const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
    if (llama_grpc_servers != NULL) {
        add_rpc_devices(std::string(llama_grpc_servers));
    }
@@ -291,6 +292,7 @@ static void params_parse(const backend::ModelOptions* request,
     params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)
 
     params.embedding = request->embeddings();
+    params.reranking = request->reranking();
 
     if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
     else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
@@ -791,6 +793,93 @@ public:
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
+            return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
+        }
+
+        // Validate request
+        if (request->query().empty()) {
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided");
+        }
+
+        if (request->documents_size() == 0) {
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
+        }
+
+        // Tokenize the query
+        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, request->query(), /* add_special */ false, true)[0];
+
+        // Create and queue the task
+        json responses = json::array();
+        bool error = false;
+        std::unordered_set<int> task_ids;
+        {
+            std::vector<server_task> tasks;
+            std::vector<std::string> documents;
+            for (int i = 0; i < request->documents_size(); i++) {
+                documents.push_back(request->documents(i));
+            }
+
+            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
+            tasks.reserve(tokenized_docs.size());
+            for (size_t i = 0; i < tokenized_docs.size(); i++) {
+                auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
+                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+                task.id = ctx_server.queue_tasks.get_new_id();
+                task.index = i;
+                task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
+                tasks.push_back(std::move(task));
+            }
+
+            task_ids = server_task::get_list_id(tasks);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(std::move(tasks));
+        }
+
+        // Get the results
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+            for (auto & res : results) {
+                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
+            }
+        }, [&](const json & error_data) {
+            error = true;
+        }, [&]() {
+            return false;
+        });
+
+        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+
+        if (error) {
+            return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
+        }
+
+        // Set usage information
+        backend::Usage* usage = rerankResult->mutable_usage();
+        int total_tokens = 0;
+        int prompt_tokens = 0;
+
+        // Create document results
+        for (const auto& response : responses) {
+            backend::DocumentResult* doc_result = rerankResult->add_results();
+            doc_result->set_index(response.value("index", 0));
+            doc_result->set_text(request->documents(response.value("index", 0)));
+            doc_result->set_relevance_score(response.value("score", 0.0f));
+
+            // Add tokens evaluated for this document
+            int tokens_evaluated = response.value("tokens_evaluated", 0);
+            total_tokens += tokens_evaluated;
+            prompt_tokens += tokens_evaluated;
+        }
+
+        // Set the total tokens in usage
+        usage->set_total_tokens(total_tokens);
+        usage->set_prompt_tokens(prompt_tokens);
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
         json body = parse_options(false, request);
         body["stream"] = false;
diff --git a/backend/go/llm/llama/llama.go b/backend/go/llm/llama/llama.go
index 33eb708b..011023fe 100644
--- a/backend/go/llm/llama/llama.go
+++ b/backend/go/llm/llama/llama.go
@@ -58,6 +58,9 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
 	if opts.Embeddings {
 		llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
 	}
+	if opts.Reranking {
+		llamaOpts = append(llamaOpts, llama.EnableReranking)
+	}
 	if opts.NGPULayers != 0 {
 		llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
 	}
diff --git a/core/backend/options.go b/core/backend/options.go
index ab602b1d..7d4754c7 100644
--- a/core/backend/options.go
+++ b/core/backend/options.go
@@ -94,6 +94,11 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
 		lowVRAM = *c.LowVRAM
 	}
 
+	reranking := false
+	if c.Reranking != nil {
+		reranking = *c.Reranking
+	}
+
 	mmap := false
 	if c.MMap != nil {
 		mmap = *c.MMap
@@ -178,6 +183,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
 		RopeFreqScale: c.RopeFreqScale,
 		NUMA:          c.NUMA,
 		Embeddings:    embeddings,
+		Reranking:     reranking,
 		LowVRAM:       lowVRAM,
 		NGPULayers:    int32(nGPULayers),
 		MMap:          mmap,
diff --git a/core/config/backend_config.go b/core/config/backend_config.go
index 5c436400..ec0f2812 100644
--- a/core/config/backend_config.go
+++ b/core/config/backend_config.go
@@ -120,6 +120,7 @@ type LLMConfig struct {
 	MMap       *bool    `yaml:"mmap"`
 	MMlock     *bool    `yaml:"mmlock"`
 	LowVRAM    *bool    `yaml:"low_vram"`
+	Reranking  *bool    `yaml:"reranking"`
 	Grammar    string   `yaml:"grammar"`
 	StopWords  []string `yaml:"stopwords"`
 	Cutstrings []string `yaml:"cutstrings"`
@@ -372,6 +373,10 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
 		cfg.Embeddings = &falseV
 	}
 
+	if cfg.Reranking == nil {
+		cfg.Reranking = &falseV
+	}
+
 	if threads == 0 {
 		// Threads can't be 0
 		threads = 4