mirror of https://github.com/mudler/LocalAI.git
synced 2025-05-20 18:45:00 +00:00

Merge 8fea82e68b into 91ef58ee5a
commit dff75bb64f
2 changed files with 115 additions and 5 deletions
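
In short: this change adds reranking support to the llama.cpp backend. A `reranker` flag is threaded through `llama_client_slot` and, as `reranking_mode`, through `task_server`; `request_completion` gains a `rerank` parameter (all existing call sites pass `false`); a new `send_rerank` helper reports a relevance score taken from the model's pooled embedding output; and a new `Rerank` gRPC handler exposes the feature.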
@@ -217,6 +217,7 @@ struct llama_client_slot
     bool infill = false;
     bool embedding = false;
+    bool reranker = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -535,6 +536,12 @@ struct llama_server_context
             return false;
         }

+        // Enable reranking if embeddings are enabled - moved after context initialization
+        if (params.embedding) {
+            params.reranking = true;
+            LOG_INFO("Reranking enabled (embeddings are enabled)", {});
+        }
+
         if (multimodal) {
             const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
             const int n_embd_llm = llama_model_n_embd(model);
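
Worth noting: rerank scores are read back through the same embedding output path, so the backend force-enables `params.reranking` whenever embeddings are on; per the comment, the assignment was moved after context initialization so the flag is set once setup is complete.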
@@ -1413,7 +1420,59 @@ struct llama_server_context
        queue_results.send(res);
    }

-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void send_rerank(llama_client_slot &slot, const llama_batch & batch)
+    {
+        task_result res;
+        res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
+        res.error = false;
+        res.stop = true;
+
+        float score = -1e6f; // Default score if we fail to get embeddings
+
+        if (!params.reranking)
+        {
+            LOG_WARNING("reranking disabled", {
+                {"params.reranking", params.reranking},
+            });
+        }
+        else if (ctx == nullptr)
+        {
+            LOG_ERR("context is null, cannot perform reranking");
+            res.error = true;
+        }
+        else
+        {
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+                    continue;
+                }
+
+                score = embd[0];
+            }
+        }
+
+        // Format result as JSON similar to the embedding function
+        res.result_json = json
+        {
+            {"score", score},
+            {"tokens", slot.num_prompt_tokens}
+        };
+
+        queue_results.send(res);
+    }
+
+    void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
    {
        task_server task;
        task.id = task_id;
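
The core of `send_rerank` is the loop that recovers a single relevance score from the decoded batch. Below is a minimal standalone sketch of that logic, assuming a llama.cpp context configured for rank pooling; the helper name `extract_rerank_score` is hypothetical and not part of the patch:

    #include "llama.h"

    // Hypothetical helper mirroring the loop in send_rerank: find the output
    // token belonging to seq_id, prefer the pooled per-sequence embedding, and
    // fall back to the per-token embedding. With rank pooling, the first
    // component of the embedding is the relevance score.
    static float extract_rerank_score(llama_context * ctx, const llama_batch & batch, llama_seq_id seq_id) {
        float score = -1e6f; // sentinel when no embeddings can be fetched
        for (int i = 0; i < batch.n_tokens; ++i) {
            if (!batch.logits[i] || batch.seq_id[i][0] != seq_id) {
                continue; // skip tokens without output, or belonging to other slots
            }
            const float * embd = llama_get_embeddings_seq(ctx, seq_id);
            if (embd == nullptr) {
                embd = llama_get_embeddings_ith(ctx, i); // per-token fallback
            }
            if (embd != nullptr) {
                score = embd[0];
            }
        }
        return score;
    }

Because `score` starts at -1e6f, a failure to fetch embeddings surfaces as an extremely low relevance rather than a hard error.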
@@ -1421,6 +1480,7 @@ struct llama_server_context
        task.data = std::move(data);
        task.infill_mode = infill;
        task.embedding_mode = embedding;
+        task.reranking_mode = rerank;
        task.type = TASK_TYPE_COMPLETION;
        task.multitask_id = multitask_id;
 
@@ -1552,7 +1612,7 @@ struct llama_server_context
                subtask_data["prompt"] = subtask_data["prompt"][i];

                // subtasks inherit everything else (infill mode, embedding mode, etc.)
-                request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+                request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
            }
        }
 
@@ -1591,6 +1651,7 @@ struct llama_server_context

        slot->infill = task.infill_mode;
        slot->embedding = task.embedding_mode;
+        slot->reranker = task.reranking_mode;
        slot->task_id = task.id;
        slot->multitask_id = task.multitask_id;
 
@@ -2034,6 +2095,14 @@ struct llama_server_context
                continue;
            }

+            if (slot.reranker)
+            {
+                send_rerank(slot, batch_view);
+                slot.release();
+                slot.i_batch = -1;
+                continue;
+            }
+
            completion_token_output result;
            const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
 
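
A slot in reranker mode therefore never reaches the sampler: `send_rerank` emits the score, the slot is released, and its batch index is reset, so rerank tasks generate no tokens.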
@@ -2489,7 +2558,7 @@ public:
        json data = parse_options(true, request, llama);
        const int task_id = llama.queue_tasks.get_new_id();
        llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
        while (true)
        {
            task_result result = llama.queue_results.recv(task_id);
@@ -2543,7 +2612,7 @@ public:
        json data = parse_options(false, request, llama);
        const int task_id = llama.queue_tasks.get_new_id();
        llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
        std::string completion_text;
        task_result result = llama.queue_results.recv(task_id);
        if (!result.error && result.stop) {
@@ -2580,7 +2649,7 @@ public:
        json data = parse_options(false, request, llama);
        const int task_id = llama.queue_tasks.get_new_id();
        llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, -1);
+        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, false, -1);
        // get the result
        task_result result = llama.queue_results.recv(task_id);
        //std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
@@ -2612,6 +2681,46 @@ public:
        return grpc::Status::OK;
    }

+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        // Create a JSON object with the query and documents
+        json data = {
+            {"prompt", request->query()},
+            {"documents", request->documents()},
+            {"top_n", request->top_n()}
+        };
+
+        // Generate a new task ID
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+
+        // Queue the task with reranking mode enabled
+        llama.request_completion(task_id, data, false, false, true, -1);
+
+        // Get the result
+        task_result result = llama.queue_results.recv(task_id);
+        llama.queue_results.remove_waiting_task_id(task_id);
+
+        if (!result.error && result.stop) {
+            // Set usage information
+            backend::Usage* usage = rerankResult->mutable_usage();
+            usage->set_total_tokens(result.result_json.value("tokens", 0));
+            usage->set_prompt_tokens(result.result_json.value("tokens", 0));
+
+            // Get the score from the result
+            float score = result.result_json.value("score", 0.0f);
+
+            // Create document results for each input document
+            for (int i = 0; i < request->documents_size(); i++) {
+                backend::DocumentResult* doc_result = rerankResult->add_results();
+                doc_result->set_index(i);
+                doc_result->set_text(request->documents(i));
+                doc_result->set_relevance_score(score);
+            }
+        }
+
+        return grpc::Status::OK;
+    }
+
    grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
        llama_client_slot* active_slot = llama.get_active_slot();
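
For reference, a hypothetical client-side sketch of the new RPC. The stub and message types are assumed to come from the code generated for LocalAI's backend.proto (service `backend::Backend`); the address, query, and documents are placeholders:

    #include <cstdio>
    #include <memory>
    #include <grpcpp/grpcpp.h>
    #include "backend.grpc.pb.h" // assumed generated header for backend.proto

    int main() {
        // Assumed address of a running llama backend; adjust for your deployment.
        auto channel = grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
        std::unique_ptr<backend::Backend::Stub> stub = backend::Backend::NewStub(channel);

        backend::RerankRequest request;
        request.set_query("what is a panda?");
        request.add_documents("The giant panda is a bear species endemic to China.");
        request.add_documents("Bamboo is a fast-growing grass.");
        request.set_top_n(2);

        backend::RerankResult result;
        grpc::ClientContext ctx;
        grpc::Status status = stub->Rerank(&ctx, request, &result);
        if (status.ok()) {
            for (const auto & doc : result.results()) {
                // Each result echoes the document index/text plus a relevance score.
                std::printf("[%d] score=%.4f  %s\n", doc.index(), doc.relevance_score(), doc.text().c_str());
            }
        }
        return 0;
    }

Note that, as the handler stands, one score is computed per task and assigned to every returned document, and `top_n` is forwarded in the task data but not used to truncate the result list.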
backend/cpp/llama/utils.hpp (vendored): 1 addition
@@ -61,6 +61,7 @@ struct task_server {
    json data;
    bool infill_mode = false;
    bool embedding_mode = false;
+    bool reranking_mode = false;
    int multitask_id = -1;
};
 