mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
fix(llama.cpp): correctly handle embeddings in batches (#4957)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
69caccfa82
commit
e4fa894153
1 changed files with 32 additions and 4 deletions
|
@ -1350,7 +1350,7 @@ struct llama_server_context
|
||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_embedding(llama_client_slot &slot)
|
void send_embedding(llama_client_slot &slot, const llama_batch & batch)
|
||||||
{
|
{
|
||||||
task_result res;
|
task_result res;
|
||||||
res.id = slot.task_id;
|
res.id = slot.task_id;
|
||||||
|
@ -1372,10 +1372,38 @@ struct llama_server_context
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
const float *data = llama_get_embeddings(ctx);
|
const float *data = llama_get_embeddings(ctx);
|
||||||
std::vector<float> embedding(data, data + n_embd);
|
std::vector<float> embd_res(n_embd, 0.0f);
|
||||||
|
std::vector<std::vector<float>> embedding;
|
||||||
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
|
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||||
|
if (embd == NULL) {
|
||||||
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embd == NULL) {
|
||||||
|
LOG("failed to get embeddings");
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalize only when there is pooling
|
||||||
|
// TODO: configurable
|
||||||
|
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
|
||||||
|
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
|
||||||
|
embedding.push_back(embd_res);
|
||||||
|
} else {
|
||||||
|
embedding.push_back({ embd, embd + n_embd });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAI compat
|
||||||
res.result_json = json
|
res.result_json = json
|
||||||
{
|
{
|
||||||
{"embedding", embedding },
|
{"embedding", embedding[0] },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
|
@ -1996,7 +2024,7 @@ struct llama_server_context
|
||||||
// prompt evaluated for embedding
|
// prompt evaluated for embedding
|
||||||
if (slot.embedding)
|
if (slot.embedding)
|
||||||
{
|
{
|
||||||
send_embedding(slot);
|
send_embedding(slot, batch_view);
|
||||||
slot.release();
|
slot.release();
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
continue;
|
continue;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue