fix(llama.cpp): correctly handle embeddings in batches (#4957)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Author: Ettore Di Giacinto
Date: 2025-03-07 19:29:52 +01:00 (committed by GitHub)
parent 69caccfa82
commit e4fa894153

@@ -1350,7 +1350,7 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void send_embedding(llama_client_slot &slot)
+    void send_embedding(llama_client_slot &slot, const llama_batch & batch)
     {
         task_result res;
         res.id = slot.task_id;
@@ -1372,10 +1372,38 @@ struct llama_server_context
         else
         {
             const float *data = llama_get_embeddings(ctx);
-            std::vector<float> embedding(data, data + n_embd);
+            std::vector<float> embd_res(n_embd, 0.0f);
+            std::vector<std::vector<float>> embedding;
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+
+                    continue;
+                }
+
+                // normalize only when there is pooling
+                // TODO: configurable
+                if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
+                    common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                    embedding.push_back(embd_res);
+                } else {
+                    embedding.push_back({ embd, embd + n_embd });
+                }
+            }
+
+            // OAI compat
             res.result_json = json
             {
-                {"embedding", embedding },
+                {"embedding", embedding[0] },
             };
         }
         queue_results.send(res);
@@ -1996,7 +2024,7 @@ struct llama_server_context
                 // prompt evaluated for embedding
                 if (slot.embedding)
                 {
-                    send_embedding(slot);
+                    send_embedding(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
                     continue;
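
The change makes send_embedding scan the evaluated batch instead of reading a single context-wide embedding: for each token in batch_view that has logits enabled and belongs to the slot's sequence, it prefers the pooled per-sequence embedding (llama_get_embeddings_seq) and falls back to the per-token one (llama_get_embeddings_ith), so slots that share a batch no longer pick up each other's data. When the context uses pooling, the vector is normalized via common_embd_normalize with norm type 2 (Euclidean). Below is a minimal, self-contained sketch of that normalization step; l2_normalize is a hypothetical stand-in for illustration, not the library function.

#include <cmath>
#include <cstdio>
#include <vector>

// Scale a vector to unit Euclidean length, as common_embd_normalize
// is expected to do when called with embd_norm == 2.
static void l2_normalize(const float * inp, float * out, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) inp[i] * inp[i];
    }
    const float norm = sum > 0.0 ? (float) std::sqrt(sum) : 0.0f;
    for (int i = 0; i < n; i++) {
        out[i] = norm > 0.0f ? inp[i] / norm : 0.0f;
    }
}

int main() {
    const std::vector<float> embd = {3.0f, 4.0f};
    std::vector<float> embd_res(embd.size(), 0.0f);
    l2_normalize(embd.data(), embd_res.data(), (int) embd.size());
    printf("%.2f %.2f\n", embd_res[0], embd_res[1]); // prints: 0.60 0.80
    return 0;
}

Note that although the loop collects every matching embedding into a vector of vectors, the OpenAI-compatible response still returns only embedding[0], so the response shape is unchanged.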