mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
Merge 8fea82e68b
into 91ef58ee5a
This commit is contained in:
commit
dff75bb64f
2 changed files with 115 additions and 5 deletions
|
@ -217,6 +217,7 @@ struct llama_client_slot
|
|||
|
||||
bool infill = false;
|
||||
bool embedding = false;
|
||||
bool reranker = false;
|
||||
bool has_next_token = true;
|
||||
bool truncated = false;
|
||||
bool stopped_eos = false;
|
||||
|
@ -535,6 +536,12 @@ struct llama_server_context
|
|||
return false;
|
||||
}
|
||||
|
||||
// Enable reranking if embeddings are enabled - moved after context initialization
|
||||
if (params.embedding) {
|
||||
params.reranking = true;
|
||||
LOG_INFO("Reranking enabled (embeddings are enabled)", {});
|
||||
}
|
||||
|
||||
if (multimodal) {
|
||||
const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
|
||||
const int n_embd_llm = llama_model_n_embd(model);
|
||||
|
@ -1413,7 +1420,59 @@ struct llama_server_context
|
|||
queue_results.send(res);
|
||||
}
|
||||
|
||||
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
|
||||
void send_rerank(llama_client_slot &slot, const llama_batch & batch)
|
||||
{
|
||||
task_result res;
|
||||
res.id = slot.task_id;
|
||||
res.multitask_id = slot.multitask_id;
|
||||
res.error = false;
|
||||
res.stop = true;
|
||||
|
||||
float score = -1e6f; // Default score if we fail to get embeddings
|
||||
|
||||
if (!params.reranking)
|
||||
{
|
||||
LOG_WARNING("reranking disabled", {
|
||||
{"params.reranking", params.reranking},
|
||||
});
|
||||
}
|
||||
else if (ctx == nullptr)
|
||||
{
|
||||
LOG_ERR("context is null, cannot perform reranking");
|
||||
res.error = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
if (embd == NULL) {
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
}
|
||||
|
||||
if (embd == NULL) {
|
||||
LOG("failed to get embeddings");
|
||||
continue;
|
||||
}
|
||||
|
||||
score = embd[0];
|
||||
}
|
||||
}
|
||||
|
||||
// Format result as JSON similar to the embedding function
|
||||
res.result_json = json
|
||||
{
|
||||
{"score", score},
|
||||
{"tokens", slot.num_prompt_tokens}
|
||||
};
|
||||
|
||||
queue_results.send(res);
|
||||
}
|
||||
|
||||
void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
|
||||
{
|
||||
task_server task;
|
||||
task.id = task_id;
|
||||
|
@ -1421,6 +1480,7 @@ struct llama_server_context
|
|||
task.data = std::move(data);
|
||||
task.infill_mode = infill;
|
||||
task.embedding_mode = embedding;
|
||||
task.reranking_mode = rerank;
|
||||
task.type = TASK_TYPE_COMPLETION;
|
||||
task.multitask_id = multitask_id;
|
||||
|
||||
|
@ -1552,7 +1612,7 @@ struct llama_server_context
|
|||
subtask_data["prompt"] = subtask_data["prompt"][i];
|
||||
|
||||
// subtasks inherit everything else (infill mode, embedding mode, etc.)
|
||||
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
|
||||
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1591,6 +1651,7 @@ struct llama_server_context
|
|||
|
||||
slot->infill = task.infill_mode;
|
||||
slot->embedding = task.embedding_mode;
|
||||
slot->reranker = task.reranking_mode;
|
||||
slot->task_id = task.id;
|
||||
slot->multitask_id = task.multitask_id;
|
||||
|
||||
|
@ -2034,6 +2095,14 @@ struct llama_server_context
|
|||
continue;
|
||||
}
|
||||
|
||||
if (slot.reranker)
|
||||
{
|
||||
send_rerank(slot, batch_view);
|
||||
slot.release();
|
||||
slot.i_batch = -1;
|
||||
continue;
|
||||
}
|
||||
|
||||
completion_token_output result;
|
||||
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
|
||||
|
||||
|
@ -2489,7 +2558,7 @@ public:
|
|||
json data = parse_options(true, request, llama);
|
||||
const int task_id = llama.queue_tasks.get_new_id();
|
||||
llama.queue_results.add_waiting_task_id(task_id);
|
||||
llama.request_completion(task_id, data, false, false, -1);
|
||||
llama.request_completion(task_id, data, false, false, false, -1);
|
||||
while (true)
|
||||
{
|
||||
task_result result = llama.queue_results.recv(task_id);
|
||||
|
@ -2543,7 +2612,7 @@ public:
|
|||
json data = parse_options(false, request, llama);
|
||||
const int task_id = llama.queue_tasks.get_new_id();
|
||||
llama.queue_results.add_waiting_task_id(task_id);
|
||||
llama.request_completion(task_id, data, false, false, -1);
|
||||
llama.request_completion(task_id, data, false, false, false, -1);
|
||||
std::string completion_text;
|
||||
task_result result = llama.queue_results.recv(task_id);
|
||||
if (!result.error && result.stop) {
|
||||
|
@ -2580,7 +2649,7 @@ public:
|
|||
json data = parse_options(false, request, llama);
|
||||
const int task_id = llama.queue_tasks.get_new_id();
|
||||
llama.queue_results.add_waiting_task_id(task_id);
|
||||
llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, -1);
|
||||
llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, false, -1);
|
||||
// get the result
|
||||
task_result result = llama.queue_results.recv(task_id);
|
||||
//std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
|
||||
|
@ -2612,6 +2681,46 @@ public:
|
|||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
|
||||
// Create a JSON object with the query and documents
|
||||
json data = {
|
||||
{"prompt", request->query()},
|
||||
{"documents", request->documents()},
|
||||
{"top_n", request->top_n()}
|
||||
};
|
||||
|
||||
// Generate a new task ID
|
||||
const int task_id = llama.queue_tasks.get_new_id();
|
||||
llama.queue_results.add_waiting_task_id(task_id);
|
||||
|
||||
// Queue the task with reranking mode enabled
|
||||
llama.request_completion(task_id, data, false, false, true, -1);
|
||||
|
||||
// Get the result
|
||||
task_result result = llama.queue_results.recv(task_id);
|
||||
llama.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (!result.error && result.stop) {
|
||||
// Set usage information
|
||||
backend::Usage* usage = rerankResult->mutable_usage();
|
||||
usage->set_total_tokens(result.result_json.value("tokens", 0));
|
||||
usage->set_prompt_tokens(result.result_json.value("tokens", 0));
|
||||
|
||||
// Get the score from the result
|
||||
float score = result.result_json.value("score", 0.0f);
|
||||
|
||||
// Create document results for each input document
|
||||
for (int i = 0; i < request->documents_size(); i++) {
|
||||
backend::DocumentResult* doc_result = rerankResult->add_results();
|
||||
doc_result->set_index(i);
|
||||
doc_result->set_text(request->documents(i));
|
||||
doc_result->set_relevance_score(score);
|
||||
}
|
||||
}
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
|
||||
llama_client_slot* active_slot = llama.get_active_slot();
|
||||
|
||||
|
|
1
backend/cpp/llama/utils.hpp
vendored
1
backend/cpp/llama/utils.hpp
vendored
|
@ -61,6 +61,7 @@ struct task_server {
|
|||
json data;
|
||||
bool infill_mode = false;
|
||||
bool embedding_mode = false;
|
||||
bool reranking_mode = false;
|
||||
int multitask_id = -1;
|
||||
};
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue