Mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-20 10:35:01 +00:00)

Commit 8dab879813: Merge 5bf05cec1f into f8fbfd4fa3

5 changed files with 105 additions and 0 deletions

@@ -255,6 +255,8 @@ message ModelOptions {
   string CacheTypeValue = 64;
 
   repeated GrammarTrigger GrammarTriggers = 65;
+
+  bool Reranking = 71;
 }
 
 message Result {

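From the Go side, the new field simply becomes a boolean on the generated pb.ModelOptions struct (it is read as opts.Reranking and set as Reranking: reranking in the changes further down). A minimal sketch of a caller enabling it when loading a model, assuming the service compiles to the usual pb.BackendClient with a LoadModel RPC; the import path and model file name are illustrative only:

import (
	"context"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto" // assumed import path of the generated stubs
)

// loadReranker asks a backend to load a model with reranking enabled.
func loadReranker(ctx context.Context, client pb.BackendClient) (*pb.Result, error) {
	return client.LoadModel(ctx, &pb.ModelOptions{
		Model:      "bge-reranker-v2-m3.gguf", // assumed reranker GGUF file
		Embeddings: false, // the Rerank handler below refuses servers started for embeddings
		Reranking:  true,  // maps to `bool Reranking = 71;` above
	})
}
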
@@ -231,6 +231,7 @@ static void params_parse(const backend::ModelOptions* request,
         params.n_parallel = 1;
     }
 
+
     const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
     if (llama_grpc_servers != NULL) {
         add_rpc_devices(std::string(llama_grpc_servers));

@@ -291,6 +292,7 @@ static void params_parse(const backend::ModelOptions* request,
     params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)
 
     params.embedding = request->embeddings();
+    params.reranking = request->reranking();
 
     if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
     else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }

@@ -791,6 +793,93 @@ public:
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
+            return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
+        }
+
+        // Validate request
+        if (request->query().empty()) {
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided");
+        }
+
+        if (request->documents_size() == 0) {
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
+        }
+
+        // Tokenize the query
+        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, request->query(), /* add_special */ false, true)[0];
+
+        // Create and queue the task
+        json responses = json::array();
+        bool error = false;
+        std::unordered_set<int> task_ids;
+        {
+            std::vector<server_task> tasks;
+            std::vector<std::string> documents;
+            for (int i = 0; i < request->documents_size(); i++) {
+                documents.push_back(request->documents(i));
+            }
+
+            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
+            tasks.reserve(tokenized_docs.size());
+            for (size_t i = 0; i < tokenized_docs.size(); i++) {
+                auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
+                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+                task.id = ctx_server.queue_tasks.get_new_id();
+                task.index = i;
+                task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
+                tasks.push_back(std::move(task));
+            }
+
+            task_ids = server_task::get_list_id(tasks);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(std::move(tasks));
+        }
+
+        // Get the results
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+            for (auto & res : results) {
+                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
+            }
+        }, [&](const json & error_data) {
+            error = true;
+        }, [&]() {
+            return false;
+        });
+
+        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+
+        if (error) {
+            return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
+        }
+
+        // Set usage information
+        backend::Usage* usage = rerankResult->mutable_usage();
+        int total_tokens = 0;
+        int prompt_tokens = 0;
+
+        // Create document results
+        for (const auto& response : responses) {
+            backend::DocumentResult* doc_result = rerankResult->add_results();
+            doc_result->set_index(response.value("index", 0));
+            doc_result->set_text(request->documents(response.value("index", 0)));
+            doc_result->set_relevance_score(response.value("score", 0.0f));
+
+            // Add tokens evaluated for this document
+            int tokens_evaluated = response.value("tokens_evaluated", 0);
+            total_tokens += tokens_evaluated;
+            prompt_tokens += tokens_evaluated;
+        }
+
+        // Set the total tokens in usage
+        usage->set_total_tokens(total_tokens);
+        usage->set_prompt_tokens(prompt_tokens);
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
         json body = parse_options(false, request);
         body["stream"] = false;

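For orientation, a minimal client-side sketch of the new RPC. The request and response field names mirror the accessors used in the handler above (query, documents, index, text, relevance_score, usage); the pb.NewBackendClient constructor, the import path and the listen address are assumptions:

package main

import (
	"context"
	"fmt"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto" // assumed import path of the generated stubs
)

func main() {
	// Address of a llama.cpp backend started with reranking enabled (assumed).
	conn, err := grpc.NewClient("127.0.0.1:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)

	// The single query is scored against every candidate document.
	res, err := client.Rerank(context.Background(), &pb.RerankRequest{
		Query:     "What is the capital of France?",
		Documents: []string{"Paris is the capital of France.", "Berlin is the capital of Germany."},
	})
	if err != nil {
		log.Fatal(err)
	}

	for _, r := range res.Results {
		// Index points back into Documents; Text echoes the document itself.
		fmt.Printf("doc %d score=%.3f text=%q\n", r.Index, r.RelevanceScore, r.Text)
	}
	fmt.Println("prompt tokens:", res.Usage.PromptTokens, "total tokens:", res.Usage.TotalTokens)
}
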
@@ -58,6 +58,9 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
 	if opts.Embeddings {
 		llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
 	}
+	if opts.Reranking {
+		llamaOpts = append(llamaOpts, llama.EnableReranking)
+	}
 	if opts.NGPULayers != 0 {
 		llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
 	}

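llama.EnableReranking is consumed here the same way as llama.EnableEmbeddings, that is, as a functional option on the binding. A rough, purely illustrative sketch of the pattern; the actual go-llama.cpp types and field names may differ:

// Illustrative stand-ins, not the real binding types.
type ModelOptions struct {
	Embeddings bool
	Reranking  bool
}

type ModelOption func(*ModelOptions)

// EnableReranking flips a flag that the loader later forwards to llama.cpp,
// mirroring how EnableEmbeddings is applied in Load() above.
var EnableReranking ModelOption = func(o *ModelOptions) { o.Reranking = true }
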
@@ -94,6 +94,11 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
 		lowVRAM = *c.LowVRAM
 	}
 
+	reranking := false
+	if c.Reranking != nil {
+		reranking = *c.Reranking
+	}
+
 	mmap := false
 	if c.MMap != nil {
 		mmap = *c.MMap

@@ -178,6 +183,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
 		RopeFreqScale: c.RopeFreqScale,
 		NUMA: c.NUMA,
 		Embeddings: embeddings,
+		Reranking: reranking,
 		LowVRAM: lowVRAM,
 		NGPULayers: int32(nGPULayers),
 		MMap: mmap,

@@ -120,6 +120,7 @@ type LLMConfig struct {
 	MMap *bool `yaml:"mmap"`
 	MMlock *bool `yaml:"mmlock"`
 	LowVRAM *bool `yaml:"low_vram"`
+	Reranking *bool `yaml:"reranking"`
 	Grammar string `yaml:"grammar"`
 	StopWords []string `yaml:"stopwords"`
 	Cutstrings []string `yaml:"cutstrings"`

@@ -372,6 +373,10 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
 		cfg.Embeddings = &falseV
 	}
 
+	if cfg.Reranking == nil {
+		cfg.Reranking = &falseV
+	}
+
 	if threads == 0 {
 		// Threads can't be 0
 		threads = 4

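On the config side, the new reranking key is declared as a *bool precisely so that SetDefaults can tell "not set" (nil, defaulted to false above) apart from an explicit false in the model YAML. A small self-contained sketch of that flow, using a trimmed-down stand-in for LLMConfig and an assumed configuration snippet:

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// llmConfig mirrors only the field relevant here; the real LLMConfig has many more.
type llmConfig struct {
	Reranking *bool `yaml:"reranking"`
}

func main() {
	doc := []byte("reranking: true\n") // assumed model configuration snippet

	var cfg llmConfig
	if err := yaml.Unmarshal(doc, &cfg); err != nil {
		panic(err)
	}

	// Same defaulting as SetDefaults above: only fill in false when the key is absent.
	if cfg.Reranking == nil {
		falseV := false
		cfg.Reranking = &falseV
	}

	fmt.Println("reranking enabled:", *cfg.Reranking)
}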