Commit 01e2e3dbc3 in mudler/LocalAI (https://github.com/mudler/LocalAI.git), parent 61cc76c455

wip reranking llama.cpp

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

2 changed files with 61 additions and 2 deletions
@@ -217,6 +217,7 @@ struct llama_client_slot
     bool infill = false;
     bool embedding = false;
+    bool reranker = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1413,7 +1414,54 @@ struct llama_server_context
         queue_results.send(res);
     }

-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void send_rerank(llama_client_slot &slot, const llama_batch & batch)
+    {
+        task_result res;
+        res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
+        res.error = false;
+        res.stop = true;
+
+        float score = -1e6f; // Default score if we fail to get embeddings
+
+        if (!params.rerank)
+        {
+            LOG_WARNING("reranking disabled", {
+                {"params.rerank", params.rerank},
+            });
+        }
+        else
+        {
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+                    continue;
+                }
+
+                score = embd[0];
+            }
+        }
+
+        // Format result as JSON similar to the embedding function
+        res.result_json = json
+        {
+            {"score", score},
+            {"tokens", slot.n_prompt_tokens}
+        };
+
+        queue_results.send(res);
+    }
+
+    void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
     {
         task_server task;
         task.id = task_id;
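Note on the score: send_rerank() takes the first component of the pooled sequence embedding as the relevance score; with llama.cpp's rank pooling, a reranker model's pooled output is a single value, so embd[0] carries that score (with -1e6f as the fallback when no embedding is available). A call site now has to pass the extra flag; a minimal sketch, assuming a llama_server_context instance named llama (the task id, payload, and multitask id below are illustrative, not from this commit):

    // Queue a rerank task: infill = false, embedding = false, rerank = true.
    json data = {
        {"prompt", "query text plus the candidate document to score"} // illustrative payload
    };
    llama.request_completion(/*task_id=*/0, std::move(data),
                             /*infill=*/false, /*embedding=*/false,
                             /*rerank=*/true, /*multitask_id=*/-1);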
@@ -1421,6 +1469,7 @@ struct llama_server_context
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
+        task.reranking_mode = rerank;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
@@ -1552,7 +1601,7 @@ struct llama_server_context
             subtask_data["prompt"] = subtask_data["prompt"][i];

             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
         }
     }

@@ -1591,6 +1640,7 @@ struct llama_server_context

         slot->infill = task.infill_mode;
         slot->embedding = task.embedding_mode;
+        slot->reranker = task.reranking_mode;
         slot->task_id = task.id;
         slot->multitask_id = task.multitask_id;

@@ -2034,6 +2084,14 @@ struct llama_server_context
                     continue;
                 }

+                if (slot.reranker)
+                {
+                    send_rerank(slot, batch_view);
+                    slot.release();
+                    slot.i_batch = -1;
+                    continue;
+                }
+
                 completion_token_output result;
                 const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
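With the dispatch above, a reranker slot never reaches token sampling: the batch loop sends the score and releases the slot immediately. Reading the result back is symmetric with embeddings; a minimal sketch, assuming the result queue's existing blocking recv(task_id) helper (variable names are illustrative):

    // Block until send_rerank() publishes the result, then unpack the JSON
    // it builds: { "score": <float>, "tokens": <n_prompt_tokens> }.
    task_result result = llama.queue_results.recv(task_id);
    if (!result.error) {
        float score  = result.result_json.value("score", -1e6f);
        int   tokens = result.result_json.value("tokens", 0);
        // ... use score/tokens ...
    }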
backend/cpp/llama/utils.hpp (vendored, 1 addition)
@@ -61,6 +61,7 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    bool reranking_mode = false;
     int multitask_id = -1;
 };
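Taken together, the new flag follows the same path as infill_mode and embedding_mode: request_completion() stores it on the task_server, the scheduler copies it onto the slot as slot->reranker, and the batch loop routes reranker slots to send_rerank() instead of sampling. A minimal sketch of the task side of that path, with illustrative values only:

    // Hypothetical task, set up the way request_completion() now does it.
    task_server task;
    task.id             = 7;                        // illustrative id
    task.data           = json{{"prompt", "..."}};
    task.infill_mode    = false;
    task.embedding_mode = false;
    task.reranking_mode = true;                     // field added by this commit
    task.type           = TASK_TYPE_COMPLETION;
    task.multitask_id   = -1;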