mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-19 18:15:00 +00:00
feat(llama.cpp): add reranking
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
41e239c67e
commit
5bf05cec1f
5 changed files with 105 additions and 0 deletions
|
@ -255,6 +255,8 @@ message ModelOptions {
|
|||
string CacheTypeValue = 64;
|
||||
|
||||
repeated GrammarTrigger GrammarTriggers = 65;
|
||||
|
||||
bool Reranking = 71;
|
||||
}
|
||||
|
||||
message Result {
|
||||
|
|
|
@ -231,6 +231,7 @@ static void params_parse(const backend::ModelOptions* request,
|
|||
params.n_parallel = 1;
|
||||
}
|
||||
|
||||
|
||||
const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
|
||||
if (llama_grpc_servers != NULL) {
|
||||
add_rpc_devices(std::string(llama_grpc_servers));
|
||||
|
@ -291,6 +292,7 @@ static void params_parse(const backend::ModelOptions* request,
|
|||
params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)
|
||||
|
||||
params.embedding = request->embeddings();
|
||||
params.reranking = request->reranking();
|
||||
|
||||
if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
|
||||
else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
|
||||
|
@ -791,6 +793,93 @@ public:
|
|||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
|
||||
if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
|
||||
return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if (request->query().empty()) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided");
|
||||
}
|
||||
|
||||
if (request->documents_size() == 0) {
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
|
||||
}
|
||||
|
||||
// Tokenize the query
|
||||
llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, request->query(), /* add_special */ false, true)[0];
|
||||
|
||||
// Create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
std::unordered_set<int> task_ids;
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
std::vector<std::string> documents;
|
||||
for (int i = 0; i < request->documents_size(); i++) {
|
||||
documents.push_back(request->documents(i));
|
||||
}
|
||||
|
||||
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
|
||||
tasks.reserve(tokenized_docs.size());
|
||||
for (size_t i = 0; i < tokenized_docs.size(); i++) {
|
||||
auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
|
||||
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
task_ids = server_task::get_list_id(tasks);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(std::move(tasks));
|
||||
}
|
||||
|
||||
// Get the results
|
||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||
for (auto & res : results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
|
||||
responses.push_back(res->to_json());
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
error = true;
|
||||
}, [&]() {
|
||||
return false;
|
||||
});
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
|
||||
if (error) {
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
|
||||
}
|
||||
|
||||
// Set usage information
|
||||
backend::Usage* usage = rerankResult->mutable_usage();
|
||||
int total_tokens = 0;
|
||||
int prompt_tokens = 0;
|
||||
|
||||
// Create document results
|
||||
for (const auto& response : responses) {
|
||||
backend::DocumentResult* doc_result = rerankResult->add_results();
|
||||
doc_result->set_index(response.value("index", 0));
|
||||
doc_result->set_text(request->documents(response.value("index", 0)));
|
||||
doc_result->set_relevance_score(response.value("score", 0.0f));
|
||||
|
||||
// Add tokens evaluated for this document
|
||||
int tokens_evaluated = response.value("tokens_evaluated", 0);
|
||||
total_tokens += tokens_evaluated;
|
||||
prompt_tokens += tokens_evaluated;
|
||||
}
|
||||
|
||||
// Set the total tokens in usage
|
||||
usage->set_total_tokens(total_tokens);
|
||||
usage->set_prompt_tokens(prompt_tokens);
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
|
||||
json body = parse_options(false, request);
|
||||
body["stream"] = false;
|
||||
|
|
|
@ -58,6 +58,9 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
|||
if opts.Embeddings {
|
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
|
||||
}
|
||||
if opts.Reranking {
|
||||
llamaOpts = append(llamaOpts, llama.EnableReranking)
|
||||
}
|
||||
if opts.NGPULayers != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
|
||||
}
|
||||
|
|
|
@ -94,6 +94,11 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
|||
lowVRAM = *c.LowVRAM
|
||||
}
|
||||
|
||||
reranking := false
|
||||
if c.Reranking != nil {
|
||||
reranking = *c.Reranking
|
||||
}
|
||||
|
||||
mmap := false
|
||||
if c.MMap != nil {
|
||||
mmap = *c.MMap
|
||||
|
@ -178,6 +183,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
|||
RopeFreqScale: c.RopeFreqScale,
|
||||
NUMA: c.NUMA,
|
||||
Embeddings: embeddings,
|
||||
Reranking: reranking,
|
||||
LowVRAM: lowVRAM,
|
||||
NGPULayers: int32(nGPULayers),
|
||||
MMap: mmap,
|
||||
|
|
|
@ -120,6 +120,7 @@ type LLMConfig struct {
|
|||
MMap *bool `yaml:"mmap"`
|
||||
MMlock *bool `yaml:"mmlock"`
|
||||
LowVRAM *bool `yaml:"low_vram"`
|
||||
Reranking *bool `yaml:"reranking"`
|
||||
Grammar string `yaml:"grammar"`
|
||||
StopWords []string `yaml:"stopwords"`
|
||||
Cutstrings []string `yaml:"cutstrings"`
|
||||
|
@ -372,6 +373,10 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
|||
cfg.Embeddings = &falseV
|
||||
}
|
||||
|
||||
if cfg.Reranking == nil {
|
||||
cfg.Reranking = &falseV
|
||||
}
|
||||
|
||||
if threads == 0 {
|
||||
// Threads can't be 0
|
||||
threads = 4
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue