// llama.cpp gRPC C++ backend server // // Ettore Di Giacinto and llama.cpp authors // // This is a gRPC server for llama.cpp compatible with the LocalAI proto // Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP (https://github.com/ggerganov/llama.cpp/tree/master/examples/server), // but modified to work with gRPC // #include "server.cpp" // LocalAI #include "backend.pb.h" #include "backend.grpc.pb.h" #include #include #include #include #include using grpc::Server; using grpc::ServerBuilder; using grpc::ServerContext; using grpc::Status; // END LocalAI ///////////////////////////////// //////////////////////////////// //////// LOCALAI code starts below here ///////////////////////////////// //////////////////////////////// bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model static void start_llama_server(server_context& ctx_server) { LOG_INF("%s: starting llama server\n", __func__); LOG_INF("%s: waiting for model to be loaded\n", __func__); // Wait for model to be loaded first while (!loaded_model) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); } ctx_server.init(); //state.store(SERVER_STATE_READY); LOG_INF("%s: model loaded\n", __func__); // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, common_chat_templates_source(ctx_server.chat_templates.get()), common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str()); // Reset the chat templates // TODO: We should make this configurable by respecting the option that is already present in LocalAI for vLLM ctx_server.chat_templates.reset(); ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { ctx_server.process_single_task(std::move(task)); }); ctx_server.queue_tasks.on_update_slots([&ctx_server]() { ctx_server.update_slots(); }); shutdown_handler = [&](int) { // this will unblock start_loop() ctx_server.queue_tasks.terminate(); }; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; sigint_action.sa_handler = signal_handler; sigemptyset (&sigint_action.sa_mask); sigint_action.sa_flags = 0; sigaction(SIGINT, &sigint_action, NULL); sigaction(SIGTERM, &sigint_action, NULL); #elif defined (_WIN32) auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; }; SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif // this call blocks the main thread until queue_tasks.terminate() is called ctx_server.queue_tasks.start_loop(); } json parse_options(bool streaming, const backend::PredictOptions* predict) { // Create now a json data from the prediction options instead // json data; data["stream"] = streaming; data["cache_prompt"] = predict->promptcacheall(); data["n_predict"] = predict->tokens() == 0 ? -1 : predict->tokens(); data["top_k"] = predict->topk(); data["top_p"] = predict->topp(); data["typical_p"] = predict->typicalp(); data["temperature"] = predict->temperature(); data["repeat_last_n"] = predict->repeat(); data["repeat_penalty"] = predict->penalty(); data["frequency_penalty"] = predict->frequencypenalty(); data["presence_penalty"] = predict->presencepenalty(); data["mirostat"] = predict->mirostat(); data["mirostat_tau"] = predict->mirostattau(); data["mirostat_eta"] = predict->mirostateta(); data["n_keep"] = predict->nkeep(); data["seed"] = predict->seed(); data["grammar"] = predict->grammar(); data["prompt"] = predict->prompt(); data["ignore_eos"] = predict->ignoreeos(); data["embeddings"] = predict->embeddings(); // TODO: add back json_schema and let this be controlled by the user // data["json_schema"] = predict->jsonschema(); // Add the correlationid to json data data["correlation_id"] = predict->correlationid(); // for each image in the request, add the image data // for (int i = 0; i < predict->images_size(); i++) { data["image_data"].push_back(json { {"id", i}, {"data", predict->images(i)}, }); } data["stop"] = predict->stopprompts(); // data["n_probs"] = predict->nprobs(); //TODO: images, return data; } const std::vector kv_cache_types = { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, }; static ggml_type kv_cache_type_from_str(const std::string & s) { for (const auto & type : kv_cache_types) { if (ggml_type_name(type) == s) { return type; } } throw std::runtime_error("Unsupported cache type: " + s); } static std::string get_all_kv_cache_types() { std::ostringstream msg; for (const auto & type : kv_cache_types) { msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", "); } return msg.str(); } // Adds an RPC server // https://github.com/ggerganov/llama.cpp/compare/4dbc8b9cb71876e005724f4e8f73a3544646bcf5..3edfa7d3753c29e44b964c0ff424d2ea8d5fdee6 static void add_rpc_devices(std::string servers) { auto rpc_servers = string_split(servers, ','); if (rpc_servers.empty()) { throw std::invalid_argument("no RPC servers specified"); } ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); if (!rpc_reg) { throw std::invalid_argument("failed to find RPC backend"); } typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); if (!ggml_backend_rpc_add_device_fn) { throw std::invalid_argument("failed to find RPC device add function"); } for (const auto & server : rpc_servers) { ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); if (dev) { ggml_backend_device_register(dev); } else { throw std::invalid_argument("failed to register RPC device"); } } } static void params_parse(const backend::ModelOptions* request, common_params & params) { // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809 params.model.path = request->modelfile(); if (!request->mmproj().empty()) { // get the directory of modelfile std::string model_dir = params.model.path.substr(0, params.model.path.find_last_of("/\\")); params.mmproj.path = model_dir + "/"+ request->mmproj(); } // params.model_alias ?? params.model_alias = request->modelfile(); if (!request->cachetypekey().empty()) { params.cache_type_k = kv_cache_type_from_str(request->cachetypekey()); } if (!request->cachetypevalue().empty()) { params.cache_type_v = kv_cache_type_from_str(request->cachetypevalue()); } params.n_ctx = request->contextsize(); //params.memory_f16 = request->f16memory(); params.cpuparams.n_threads = request->threads(); params.n_gpu_layers = request->ngpulayers(); params.n_batch = request->nbatch(); // Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1 //params.n_parallel = 1; const char *env_parallel = std::getenv("LLAMACPP_PARALLEL"); if (env_parallel != NULL) { params.n_parallel = std::stoi(env_parallel); params.cont_batching = true; } else { params.n_parallel = 1; } const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS"); if (llama_grpc_servers != NULL) { add_rpc_devices(std::string(llama_grpc_servers)); } // decode options. Options are in form optname:optvale, or if booleans only optname. for (int i = 0; i < request->options_size(); i++) { std::string opt = request->options(i); char *optname = strtok(&opt[0], ":"); char *optval = strtok(NULL, ":"); if (optval == NULL) { optval = "true"; } if (!strcmp(optname, "gpu")) { // llama.has_gpu = true; } } // TODO: Add yarn if (!request->tensorsplit().empty()) { std::string arg_next = request->tensorsplit(); // split string by , and / const std::regex regex{ R"([,/]+)" }; std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; std::vector split_arg{ it, {} }; GGML_ASSERT(split_arg.size() <= llama_max_devices()); for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) { if (i_device < split_arg.size()) { params.tensor_split[i_device] = std::stof(split_arg[i_device]); } else { params.tensor_split[i_device] = 0.0f; } } } if (!request->maingpu().empty()) { params.main_gpu = std::stoi(request->maingpu()); } if (!request->loraadapter().empty() && !request->lorabase().empty()) { float scale_factor = 1.0f; if (request->lorascale() != 0.0f) { scale_factor = request->lorascale(); } // get the directory of modelfile std::string model_dir = params.model.path.substr(0, params.model.path.find_last_of("/\\")); params.lora_adapters.push_back({ model_dir + "/"+request->loraadapter(), scale_factor }); } params.use_mlock = request->mlock(); params.use_mmap = request->mmap(); params.flash_attn = request->flashattention(); params.no_kv_offload = request->nokvoffload(); params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops) params.embedding = request->embeddings(); params.reranking = request->reranking(); if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } else { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } if ( request->yarnextfactor() != 0.0f ) { params.yarn_ext_factor = request->yarnextfactor(); } if ( request->yarnattnfactor() != 0.0f ) { params.yarn_attn_factor = request->yarnattnfactor(); } if ( request->yarnbetafast() != 0.0f ) { params.yarn_beta_fast = request->yarnbetafast(); } if ( request->yarnbetaslow() != 0.0f ) { params.yarn_beta_slow = request->yarnbetaslow(); } if ( request->ropefreqbase() != 0.0f ) { params.rope_freq_base = request->ropefreqbase(); } if ( request->ropefreqscale() != 0.0f ) { params.rope_freq_scale = request->ropefreqscale(); } if (request->grammartriggers_size() > 0) { params.sampling.grammar_lazy = true; for (int i = 0; i < request->grammartriggers_size(); i++) { common_grammar_trigger trigger; trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; trigger.value = request->grammartriggers(i).word(); // trigger.at_start = request->grammartriggers(i).at_start(); params.sampling.grammar_triggers.push_back(trigger); } } } // GRPC Server start class BackendServiceImpl final : public backend::Backend::Service { private: server_context& ctx_server; public: BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {} grpc::Status Health(ServerContext* context, const backend::HealthMessage* request, backend::Reply* reply) { // Implement Health RPC reply->set_message("OK"); return Status::OK; } grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) { // Implement LoadModel RPC common_params params; params_parse(request, params); common_init(); llama_backend_init(); llama_numa_init(params.numa); LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); LOG_INF("\n"); LOG_INF("%s\n", common_params_get_system_info(params).c_str()); LOG_INF("\n"); // load the model if (!ctx_server.load_model(params)) { result->set_message("Failed loading model"); result->set_success(false); return Status::CANCELLED; } //ctx_server.init(); result->set_message("Loading succeeded"); result->set_success(true); loaded_model = true; ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; return Status::OK; } grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter* writer) override { json data = parse_options(true, request); //Raise error if embeddings is set to true if (ctx_server.params_base.embedding) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode"); } auto completion_id = gen_chatcmplid(); std::unordered_set task_ids; try { std::vector tasks; const auto & prompt = data.at("prompt"); const auto type = SERVER_TASK_TYPE_COMPLETION; // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); std::vector files; const auto &images_data = data.find("image_data"); if (images_data != data.end() && images_data->is_array()) { for (const auto &img : *images_data) { auto decoded_data = base64_decode(img["data"].get()); files.push_back(decoded_data); } } // process files mtmd::bitmaps bitmaps; const bool has_mtmd = ctx_server.mctx != nullptr; { if (!has_mtmd && !files.empty()) { throw std::runtime_error("This server does not support multimodal"); } for (auto & file : files) { mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size())); if (!bmp.ptr) { throw std::runtime_error("Failed to load image"); } // calculate bitmap hash (for KV caching) std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); bmp.set_id(hash.c_str()); bitmaps.entries.push_back(std::move(bmp)); } } // process prompt std::vector inputs; if (!prompt.is_string()) { throw std::runtime_error("prompt must be a string"); } if (has_mtmd) { // multimodal std::string prompt_str = prompt.get(); mtmd_input_text inp_txt = { prompt_str.c_str(), /* add_special */ true, /* parse_special */ true, }; mtmd::input_chunks chunks(mtmd_input_chunks_init()); auto bitmaps_c_ptr = bitmaps.c_ptr(); int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks.ptr.get(), &inp_txt, bitmaps_c_ptr.data(), bitmaps_c_ptr.size()); if (tokenized != 0) { throw std::runtime_error("Failed to tokenize prompt"); } server_tokens tmp(chunks, true); inputs.push_back(std::move(tmp)); } else { // non-multimodal version auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); for (auto & p : tokenized_prompts) { auto tmp = server_tokens(p, ctx_server.mctx != nullptr); inputs.push_back(std::move(tmp)); } } tasks.reserve(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; task.prompt_tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, data); task.id_selected_slot = json_value(data, "id_slot", -1); // OAI-compat task.params.oaicompat = OAICOMPAT_TYPE_NONE; task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl tasks.push_back(std::move(task)); } task_ids = server_task::get_list_id(tasks); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(std::move(tasks)); } catch (const std::exception & e) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what()); } ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { json res_json = result->to_json(); if (res_json.is_array()) { for (const auto & res : res_json) { std::string completion_text = res.value("content", ""); backend::Reply reply; reply.set_message(completion_text); int32_t tokens_predicted = res.value("tokens_predicted", 0); reply.set_tokens(tokens_predicted); int32_t tokens_evaluated = res.value("tokens_evaluated", 0); reply.set_prompt_tokens(tokens_evaluated); if (res.contains("timings")) { double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0); reply.set_timing_prompt_processing(timing_prompt_processing); double timing_token_generation = res.at("timings").value("predicted_ms", 0.0); reply.set_timing_token_generation(timing_token_generation); } // Log Request Correlation Id // Send the reply writer->Write(reply); } } else { std::string completion_text = res_json.value("content", ""); backend::Reply reply; reply.set_message(completion_text); int32_t tokens_predicted = res_json.value("tokens_predicted", 0); reply.set_tokens(tokens_predicted); int32_t tokens_evaluated = res_json.value("tokens_evaluated", 0); reply.set_prompt_tokens(tokens_evaluated); if (res_json.contains("timings")) { double timing_prompt_processing = res_json.at("timings").value("prompt_ms", 0.0); reply.set_timing_prompt_processing(timing_prompt_processing); double timing_token_generation = res_json.at("timings").value("predicted_ms", 0.0); reply.set_timing_token_generation(timing_token_generation); } // Send the reply writer->Write(reply); } return true; }, [&](const json & error_data) { backend::Reply reply; reply.set_message(error_data.value("content", "")); writer->Write(reply); return true; }, [&]() { // NOTE: we should try to check when the writer is closed here return false; }); ctx_server.queue_results.remove_waiting_task_ids(task_ids); return grpc::Status::OK; } grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) { json data = parse_options(true, request); data["stream"] = false; //Raise error if embeddings is set to true if (ctx_server.params_base.embedding) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in Predict mode"); } std::cout << "[PREDICT] Received result: " << data.dump(2) << std::endl; auto completion_id = gen_chatcmplid(); std::unordered_set task_ids; try { std::vector tasks; const auto & prompt = data.at("prompt"); const auto type = SERVER_TASK_TYPE_COMPLETION; // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); std::vector files; const auto &images_data = data.find("image_data"); // std::cout << "[PREDICT] Images data: " << images_data->dump(2) << std::endl; if (images_data != data.end() && images_data->is_array()) { std::cout << "[PREDICT] Processing " << images_data->size() << " images" << std::endl; for (const auto &img : *images_data) { std::cout << "[PREDICT] Processing image" << std::endl; auto decoded_data = base64_decode(img["data"].get()); files.push_back(decoded_data); } } // process files mtmd::bitmaps bitmaps; const bool has_mtmd = ctx_server.mctx != nullptr; { if (!has_mtmd && !files.empty()) { throw std::runtime_error("This server does not support multimodal"); } for (auto & file : files) { mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size())); if (!bmp.ptr) { throw std::runtime_error("Failed to load image"); } // calculate bitmap hash (for KV caching) std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); bmp.set_id(hash.c_str()); bitmaps.entries.push_back(std::move(bmp)); } } // process prompt std::vector inputs; if (!prompt.is_string()) { std::cout << "[PREDICT] Prompt must be a string" << std::endl; throw std::runtime_error("prompt must be a string"); } if (has_mtmd) { // multimodal std::string prompt_str = prompt.get(); mtmd_input_text inp_txt = { prompt_str.c_str(), /* add_special */ true, /* parse_special */ true, }; mtmd::input_chunks chunks(mtmd_input_chunks_init()); auto bitmaps_c_ptr = bitmaps.c_ptr(); int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks.ptr.get(), &inp_txt, bitmaps_c_ptr.data(), bitmaps_c_ptr.size()); if (tokenized != 0) { std::cout << "[PREDICT] Failed to tokenize prompt" << std::endl; throw std::runtime_error("Failed to tokenize prompt"); } server_tokens tmp(chunks, true); inputs.push_back(std::move(tmp)); } else { // non-multimodal version auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); for (auto & p : tokenized_prompts) { auto tmp = server_tokens(p, ctx_server.mctx != nullptr); inputs.push_back(std::move(tmp)); } } tasks.reserve(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; task.prompt_tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, data); task.id_selected_slot = json_value(data, "id_slot", -1); // OAI-compat task.params.oaicompat = OAICOMPAT_TYPE_NONE; task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl tasks.push_back(std::move(task)); } task_ids = server_task::get_list_id(tasks); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(std::move(tasks)); } catch (const std::exception & e) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what()); } std::cout << "[DEBUG] Waiting for results..." << std::endl; ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl; if (results.size() == 1) { // single result reply->set_message(results[0]->to_json().value("content", "")); int32_t tokens_predicted = results[0]->to_json().value("tokens_predicted", 0); reply->set_tokens(tokens_predicted); int32_t tokens_evaluated = results[0]->to_json().value("tokens_evaluated", 0); reply->set_prompt_tokens(tokens_evaluated); if (results[0]->to_json().contains("timings")) { double timing_prompt_processing = results[0]->to_json().at("timings").value("prompt_ms", 0.0); reply->set_timing_prompt_processing(timing_prompt_processing); double timing_token_generation = results[0]->to_json().at("timings").value("predicted_ms", 0.0); reply->set_timing_token_generation(timing_token_generation); } } else { // multiple results (multitask) json arr = json::array(); for (auto & res : results) { arr.push_back(res->to_json().value("content", "")); } reply->set_message(arr); } }, [&](const json & error_data) { std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl; reply->set_message(error_data.value("content", "")); }, [&]() { return false; }); ctx_server.queue_results.remove_waiting_task_ids(task_ids); std::cout << "[DEBUG] Predict request completed successfully" << std::endl; return grpc::Status::OK; } grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) { json body = parse_options(false, request); body["stream"] = false; /* if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Pooling type 'none' is not OAI compatible. Please use a different pooling type"); } */ // for the shape of input/content, see tokenize_input_prompts() json prompt = body.at("prompt"); auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); for (const auto & tokens : tokenized_prompts) { // this check is necessary for models that do not add BOS token to the input if (tokens.empty()) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Input content cannot be empty"); } } // create and queue the task json responses = json::array(); bool error = false; std::unordered_set task_ids; { std::vector tasks; for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr); // OAI-compat task.params.oaicompat = OAICOMPAT_TYPE_EMBEDDING; tasks.push_back(std::move(task)); } task_ids = server_task::get_list_id(tasks); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(std::move(tasks)); } // get the result ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { for (auto & res : results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); responses.push_back(res->to_json()); } }, [&](const json & error_data) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", "")); }, [&]() { // NOTE: we should try to check when the writer is closed here return false; }); ctx_server.queue_results.remove_waiting_task_ids(task_ids); if (error) { return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results"); } std::vector embeddings = responses[0].value("embedding", std::vector()); // loop the vector and set the embeddings results for (int i = 0; i < embeddings.size(); i++) { embeddingResult->add_embeddings(embeddings[i]); } return grpc::Status::OK; } grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) { if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) { return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`"); } // Validate request if (request->query().empty()) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided"); } if (request->documents_size() == 0) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array"); } // Tokenize the query llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, request->query(), /* add_special */ false, true)[0]; // Create and queue the task json responses = json::array(); bool error = false; std::unordered_set task_ids; { std::vector tasks; std::vector documents; for (int i = 0; i < request->documents_size(); i++) { documents.push_back(request->documents(i)); } auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); tasks.reserve(tokenized_docs.size()); for (size_t i = 0; i < tokenized_docs.size(); i++) { auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr); tasks.push_back(std::move(task)); } task_ids = server_task::get_list_id(tasks); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(std::move(tasks)); } // Get the results ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { for (auto & res : results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); responses.push_back(res->to_json()); } }, [&](const json & error_data) { error = true; }, [&]() { return false; }); ctx_server.queue_results.remove_waiting_task_ids(task_ids); if (error) { return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results"); } // Set usage information backend::Usage* usage = rerankResult->mutable_usage(); int total_tokens = 0; int prompt_tokens = 0; // Create document results for (const auto& response : responses) { backend::DocumentResult* doc_result = rerankResult->add_results(); doc_result->set_index(response.value("index", 0)); doc_result->set_text(request->documents(response.value("index", 0))); doc_result->set_relevance_score(response.value("score", 0.0f)); // Add tokens evaluated for this document int tokens_evaluated = response.value("tokens_evaluated", 0); total_tokens += tokens_evaluated; prompt_tokens += tokens_evaluated; } // Set the total tokens in usage usage->set_total_tokens(total_tokens); usage->set_prompt_tokens(prompt_tokens); return grpc::Status::OK; } grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) { json body = parse_options(false, request); body["stream"] = false; json tokens_response = json::array(); if (body.count("prompt") != 0) { const bool add_special = json_value(body, "add_special", false); const bool with_pieces = json_value(body, "with_pieces", false); llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true); for (const auto& token : tokens) { std::string piece = common_token_to_piece(ctx_server.ctx, token); response->add_tokens(token); } } return grpc::Status::OK; } grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) { // request slots data using task queue int task_id = ctx_server.queue_tasks.get_new_id(); { server_task task(SERVER_TASK_TYPE_METRICS); task.id = task_id; ctx_server.queue_results.add_waiting_task_id(task_id); ctx_server.queue_tasks.post(std::move(task), true); // high-priority task } // get the result server_task_result_ptr result = ctx_server.queue_results.recv(task_id); ctx_server.queue_results.remove_waiting_task_id(task_id); if (result->is_error()) { // Handle case when no active slot exists response->set_slot_id(0); response->set_prompt_json_for_slot(""); response->set_tokens_per_second(0); response->set_tokens_generated(0); response->set_prompt_tokens_processed(0); return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results"); } // TODO: get rid of this dynamic_cast auto res_metrics = dynamic_cast(result.get()); GGML_ASSERT(res_metrics != nullptr); // Populate the response with metrics response->set_slot_id(0); response->set_prompt_json_for_slot(""); response->set_tokens_per_second(res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.); response->set_tokens_generated(res_metrics->n_tokens_predicted_total); response->set_prompt_tokens_processed(res_metrics->n_prompt_tokens_processed_total); return grpc::Status::OK; } }; int main(int argc, char** argv) { std::string server_address("localhost:50051"); // Define long and short options struct option long_options[] = { {"addr", required_argument, nullptr, 'a'}, {nullptr, 0, nullptr, 0} }; // Parse command-line arguments int option; int option_index = 0; while ((option = getopt_long(argc, argv, "a:", long_options, &option_index)) != -1) { switch (option) { case 'a': server_address = optarg; break; default: std::cerr << "Usage: " << argv[0] << " [--addr=
] or [-a
]" << std::endl; return 1; } } server_context ctx_server; BackendServiceImpl service(ctx_server); ServerBuilder builder; builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); builder.RegisterService(&service); builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB std::unique_ptr server(builder.BuildAndStart()); // run the HTTP server in a thread - see comment below std::thread t([&]() { std::cout << "Server listening on " << server_address << std::endl; server->Wait(); return 0; }); // clean up function, to be called before exit auto clean_up = [&server, &ctx_server]() { SRV_INF("%s: cleaning up before exit...\n", __func__); server->Shutdown(); ctx_server.queue_results.terminate(); llama_backend_free(); }; //); start_llama_server(ctx_server); std::cout << "stopping" << std::endl; clean_up(); t.join(); return 0; }