From 6381f9bda2230e2e07876e3674d4ae5bc781cf8d Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 15 May 2025 22:41:42 +0200
Subject: [PATCH] Make it compile

---
 backend/cpp/llama/grpc-server.cpp | 183 +++++++++++++++++++-----------
 1 file changed, 119 insertions(+), 64 deletions(-)

diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index aeb8d409..f02fd20e 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -27,9 +27,16 @@
 #include
 #include
 #include
+
+// LocalAI
+
+#include "backend.pb.h"
+#include "backend.grpc.pb.h"
+#include
 #include
 #include
 #include
+#include

 using grpc::Server;
@@ -37,6 +44,7 @@
 using grpc::ServerBuilder;
 using grpc::ServerContext;
 using grpc::Status;
 using json = nlohmann::ordered_json;
+// END LocalAI

 constexpr int HTTP_POLLING_SECONDS = 1;
@@ -336,22 +344,23 @@ struct server_task {
             }
         }
-        // process "json_schema" and "grammar"
-        if (data.contains("json_schema") && !data.contains("grammar")) {
-            try {
-                auto schema = json_value(data, "json_schema", json::object());
-                SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
-                params.sampling.grammar = json_schema_to_grammar(schema);
-                SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
-            } catch (const std::exception & e) {
-                throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
-            }
-        } else {
-            params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
-            SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
-            params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
-            SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
-        }
+        // TODO: add back json_schema and grammar support
+        // // process "json_schema" and "grammar"
+        // if (data.contains("json_schema") && !data.contains("grammar")) {
+        //     try {
+        //         auto schema = json_value(data, "json_schema", json::object());
+        //         SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
+        //         params.sampling.grammar = json_schema_to_grammar(schema);
+        //         SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
+        //     } catch (const std::exception & e) {
+        //         throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+        //     }
+        // } else {
+        //     params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+        //     SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
+        //     params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
+        //     SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
"true" : "false"); + // } { auto it = data.find("chat_format"); @@ -3558,9 +3567,6 @@ inline void signal_handler(int signal) { bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model -// The class has a llama instance that is shared across all RPCs -llama_server_context llama; - static void start_llama_server(server_context& ctx_server) { // Wait for model to be loaded first while (!loaded_model) { @@ -3568,7 +3574,7 @@ static void start_llama_server(server_context& ctx_server) { } ctx_server.init(); - state.store(SERVER_STATE_READY); + //state.store(SERVER_STATE_READY); LOG_INF("%s: model loaded\n", __func__); @@ -3604,8 +3610,6 @@ static void start_llama_server(server_context& ctx_server) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port); - // this call blocks the main thread until queue_tasks.terminate() is called ctx_server.queue_tasks.start_loop(); } @@ -3687,8 +3691,35 @@ static std::string get_all_kv_cache_types() { return msg.str(); } + +// Adds an RPC server +// https://github.com/ggerganov/llama.cpp/compare/4dbc8b9cb71876e005724f4e8f73a3544646bcf5..3edfa7d3753c29e44b964c0ff424d2ea8d5fdee6 +static void add_rpc_devices(std::string servers) { + auto rpc_servers = string_split(servers, ','); + if (rpc_servers.empty()) { + throw std::invalid_argument("no RPC servers specified"); + } + ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); + if (!rpc_reg) { + throw std::invalid_argument("failed to find RPC backend"); + } + typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); + ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); + if (!ggml_backend_rpc_add_device_fn) { + throw std::invalid_argument("failed to find RPC device add function"); + } + for (const auto & server : rpc_servers) { + ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); + if (dev) { + ggml_backend_device_register(dev); + } else { + throw std::invalid_argument("failed to register RPC device"); + } + } +} + static void params_parse(const backend::ModelOptions* request, - common_params & params, llama_server_context &llama) { + common_params & params) { // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809 @@ -3736,7 +3767,7 @@ static void params_parse(const backend::ModelOptions* request, } if (!strcmp(optname, "gpu")) { - llama.has_gpu = true; + // llama.has_gpu = true; } } @@ -3805,7 +3836,6 @@ static void params_parse(const backend::ModelOptions* request, } if (request->grammartriggers_size() > 0) { - LOG_INFO("configuring grammar triggers", {}); params.sampling.grammar_lazy = true; for (int i = 0; i < request->grammartriggers_size(); i++) { common_grammar_trigger trigger; @@ -3813,9 +3843,7 @@ static void params_parse(const backend::ModelOptions* request, trigger.value = request->grammartriggers(i).word(); // trigger.at_start = request->grammartriggers(i).at_start(); params.sampling.grammar_triggers.push_back(trigger); - LOG_INFO("grammar trigger", { - { "word", trigger.value }, - }); + } } } @@ -3857,6 +3885,8 @@ public: result->set_message("Loading succeeded"); result->set_success(true); loaded_model = true; + ctx_server.slot_prompt_similarity = 
+
         return Status::OK;
     }
@@ -3876,9 +3906,23 @@
        std::vector tasks;
        const auto & prompt = data.at("prompt");
+       const auto type = SERVER_TASK_TYPE_COMPLETION;
        // TODO: this log can become very long, put it behind a flag or think about a more compact format
        //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str());

+       std::vector files;
+       const auto &images_data = data.find("image_data");
+       if (images_data != data.end() && images_data->is_array())
+       {
+           for (const auto &img : *images_data)
+           {
+               const std::vector image_buffer = base64_decode(img["data"].get());
+               raw_buffer data;
+               data.insert(data.end(), image_buffer.begin(), image_buffer.end());
+               files.push_back(data);
+           }
+       }
+
        // process files
        mtmd::bitmaps bitmaps;
        const bool has_mtmd = ctx_server.mctx != nullptr;
@@ -3960,8 +4004,7 @@
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(std::move(tasks));
        } catch (const std::exception & e) {
-           res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
-           return;
+           return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
        }

        ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
@@ -3973,9 +4016,9 @@
            backend::Reply reply;
            reply.set_message(completion_text);
            int32_t tokens_predicted = res.value("tokens_predicted", 0);
-           reply->set_tokens(tokens_predicted);
+           reply.set_tokens(tokens_predicted);
            int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
-           reply->set_prompt_tokens(tokens_evaluated);
+           reply.set_prompt_tokens(tokens_evaluated);

            if (res.contains("timings")) {
                double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
@@ -3985,10 +4028,7 @@
            }
            // Log Request Correlation Id
-           LOG_VERBOSE("correlation:", {
-               { "id", data["correlation_id"] }
-           });
-
+
            // Send the reply
            writer->Write(reply);

            }
@@ -4009,10 +4049,7 @@
                reply.set_timing_token_generation(timing_token_generation);
            }

-           // Log Request Correlation Id
-           LOG_VERBOSE("correlation:", {
-               { "id", data["correlation_id"] }
-           });
+
            // Send the reply
            writer->Write(reply);
@@ -4024,9 +4061,9 @@
            reply.set_message(error_data.value("content", ""));
            writer->Write(reply);
            return true;
-       }, [&writer]() {
-           // note: do not use req.is_connection_closed here because req is already destroyed
-           return !writer->IsWritePossible();
+       }, [&]() {
+           // NOTE: we should try to check when the writer is closed here
+           return false;
        });

        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
@@ -4035,7 +4072,8 @@
    }

    grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
-
+       json data = parse_options(true, request);
+
        //Raise error if embeddings is set to true
        if (ctx_server.params_base.embedding) {
            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode");
@@ -4047,9 +4085,22 @@
        std::vector tasks;
        const auto & prompt = data.at("prompt");
+       const auto type = SERVER_TASK_TYPE_COMPLETION;
        // TODO: this log can become very long, put it behind a flag or think about a more compact format
        //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str());
+       std::vector files;
+       const auto &images_data = data.find("image_data");
+       if (images_data != data.end() && images_data->is_array())
+       {
+           for (const auto &img : *images_data)
+           {
+               const std::vector image_buffer = base64_decode(img["data"].get());
+               raw_buffer data;
+               data.insert(data.end(), image_buffer.begin(), image_buffer.end());
+               files.push_back(data);
+           }
+       }

        // process files
        mtmd::bitmaps bitmaps;
        const bool has_mtmd = ctx_server.mctx != nullptr;
@@ -4131,8 +4182,7 @@
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(std::move(tasks));
        } catch (const std::exception & e) {
-           res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
-           return;
+           return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
        }

        ctx_server.receive_multi_results(task_ids, [&](std::vector & results) {
@@ -4209,9 +4259,11 @@
                responses.push_back(res->to_json());
            }
        }, [&](const json & error_data) {
-           res_error(res, error_data);
-           error = true;
-       }, req.is_connection_closed);
+           return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", ""));
+       }, [&]() {
+           // NOTE: we should try to check when the writer is closed here
+           return false;
+       });

        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
@@ -4290,20 +4342,6 @@
 };

-void RunServer(const std::string& server_address, server_context& ctx_server) {
-    BackendServiceImpl service(ctx_server);
-
-    ServerBuilder builder;
-    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
-    builder.RegisterService(&service);
-    builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
-    builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
-    builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
-    std::unique_ptr server(builder.BuildAndStart());
-    std::cout << "Server listening on " << server_address << std::endl;
-    server->Wait();
-}
-
 int main(int argc, char** argv) {
     std::string server_address("localhost:50051");
@@ -4328,14 +4366,31 @@
     }

     server_context ctx_server;
+    BackendServiceImpl service(ctx_server);
+    ServerBuilder builder;
+    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
+    builder.RegisterService(&service);
+    builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
+    builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
+    builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
+    std::unique_ptr server(builder.BuildAndStart());

     // run the HTTP server in a thread - see comment below
     std::thread t([&]() {
-        RunServer(server_address, ctx_server);
+        std::cout << "Server listening on " << server_address << std::endl;
+        server->Wait();
         return 0;
     });

+    // clean up function, to be called before exit
+    auto clean_up = [&server, &ctx_server]() {
+        SRV_INF("%s: cleaning up before exit...\n", __func__);
+        server->Shutdown();
+        ctx_server.queue_results.terminate();
+        llama_backend_free();
+    };
+
     //);
     start_llama_server(ctx_server);