Make it compile

Ettore Di Giacinto 2025-05-15 22:41:42 +02:00
parent 453eb7d1c8
commit 6381f9bda2


@@ -27,9 +27,16 @@
 #include <thread>
 #include <unordered_map>
 #include <unordered_set>
+// LocalAI
+#include "backend.pb.h"
+#include "backend.grpc.pb.h"
+#include <getopt.h>
 #include <grpcpp/ext/proto_server_reflection_plugin.h>
 #include <grpcpp/grpcpp.h>
 #include <grpcpp/health_check_service_interface.h>
+#include <regex>
 using grpc::Server;
@@ -37,6 +44,7 @@ using grpc::ServerBuilder;
 using grpc::ServerContext;
 using grpc::Status;
 using json = nlohmann::ordered_json;
+// END LocalAI
 constexpr int HTTP_POLLING_SECONDS = 1;
@@ -336,22 +344,23 @@ struct server_task {
             }
         }
-        // process "json_schema" and "grammar"
-        if (data.contains("json_schema") && !data.contains("grammar")) {
-            try {
-                auto schema = json_value(data, "json_schema", json::object());
-                SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
-                params.sampling.grammar = json_schema_to_grammar(schema);
-                SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
-            } catch (const std::exception & e) {
-                throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
-            }
-        } else {
-            params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
-            SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
-            params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
-            SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
-        }
+        // TODO: add back json_schema and grammar support
+        // // process "json_schema" and "grammar"
+        // if (data.contains("json_schema") && !data.contains("grammar")) {
+        //     try {
+        //         auto schema = json_value(data, "json_schema", json::object());
+        //         SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
+        //         params.sampling.grammar = json_schema_to_grammar(schema);
+        //         SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
+        //     } catch (const std::exception & e) {
+        //         throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+        //     }
+        // } else {
+        //     params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+        //     SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
+        //     params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
+        //     SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
+        // }
         {
             auto it = data.find("chat_format");
@@ -3558,9 +3567,6 @@ inline void signal_handler(int signal) {
 bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
-// The class has a llama instance that is shared across all RPCs
-llama_server_context llama;
 static void start_llama_server(server_context& ctx_server) {
     // Wait for model to be loaded first
     while (!loaded_model) {
@@ -3568,7 +3574,7 @@ static void start_llama_server(server_context& ctx_server) {
     }
     ctx_server.init();
-    state.store(SERVER_STATE_READY);
+    //state.store(SERVER_STATE_READY);
     LOG_INF("%s: model loaded\n", __func__);
@@ -3604,8 +3610,6 @@ static void start_llama_server(server_context& ctx_server) {
     SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif
-    LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
     // this call blocks the main thread until queue_tasks.terminate() is called
     ctx_server.queue_tasks.start_loop();
 }
@@ -3687,8 +3691,35 @@ static std::string get_all_kv_cache_types() {
     return msg.str();
 }
+// Adds an RPC server
+// https://github.com/ggerganov/llama.cpp/compare/4dbc8b9cb71876e005724f4e8f73a3544646bcf5..3edfa7d3753c29e44b964c0ff424d2ea8d5fdee6
+static void add_rpc_devices(std::string servers) {
+    auto rpc_servers = string_split<std::string>(servers, ',');
+    if (rpc_servers.empty()) {
+        throw std::invalid_argument("no RPC servers specified");
+    }
+    ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+    if (!rpc_reg) {
+        throw std::invalid_argument("failed to find RPC backend");
+    }
+    typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+    ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+    if (!ggml_backend_rpc_add_device_fn) {
+        throw std::invalid_argument("failed to find RPC device add function");
+    }
+    for (const auto & server : rpc_servers) {
+        ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+        if (dev) {
+            ggml_backend_device_register(dev);
+        } else {
+            throw std::invalid_argument("failed to register RPC device");
+        }
+    }
+}
 static void params_parse(const backend::ModelOptions* request,
-                         common_params & params, llama_server_context &llama) {
+                         common_params & params) {
     // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
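For context, the new add_rpc_devices() resolves the RPC backend through ggml's registry at runtime and registers one device per comma-separated endpoint. A minimal sketch of a call site, assuming the endpoint list has already been read from the backend options (the endpoints and the option plumbing below are illustrative, not part of this commit):

    // hypothetical: endpoints would normally come from the gRPC ModelOptions request
    std::string rpc_servers = "192.168.1.10:50052,192.168.1.11:50052";
    if (!rpc_servers.empty()) {
        add_rpc_devices(rpc_servers); // throws std::invalid_argument if the backend or an endpoint cannot be registered
    }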
@@ -3736,7 +3767,7 @@ static void params_parse(const backend::ModelOptions* request,
         }
         if (!strcmp(optname, "gpu")) {
-            llama.has_gpu = true;
+            // llama.has_gpu = true;
         }
     }
@@ -3805,7 +3836,6 @@ static void params_parse(const backend::ModelOptions* request,
     }
     if (request->grammartriggers_size() > 0) {
-        LOG_INFO("configuring grammar triggers", {});
         params.sampling.grammar_lazy = true;
         for (int i = 0; i < request->grammartriggers_size(); i++) {
             common_grammar_trigger trigger;
@@ -3813,9 +3843,7 @@ static void params_parse(const backend::ModelOptions* request,
             trigger.value = request->grammartriggers(i).word();
             // trigger.at_start = request->grammartriggers(i).at_start();
             params.sampling.grammar_triggers.push_back(trigger);
-            LOG_INFO("grammar trigger", {
-                { "word", trigger.value },
-            });
         }
     }
 }
@@ -3857,6 +3885,8 @@
         result->set_message("Loading succeeded");
         result->set_success(true);
         loaded_model = true;
+        ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
         return Status::OK;
     }
@@ -3876,9 +3906,23 @@
         std::vector<server_task> tasks;
         const auto & prompt = data.at("prompt");
+        const auto type = SERVER_TASK_TYPE_COMPLETION;
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
+        std::vector<raw_buffer> files;
+        const auto &images_data = data.find("image_data");
+        if (images_data != data.end() && images_data->is_array())
+        {
+            for (const auto &img : *images_data)
+            {
+                const std::vector<uint8_t> image_buffer = base64_decode(img["data"].get<std::string>());
+                raw_buffer data;
+                data.insert(data.end(), image_buffer.begin(), image_buffer.end());
+                files.push_back(data);
+            }
+        }
         // process files
         mtmd::bitmaps bitmaps;
         const bool has_mtmd = ctx_server.mctx != nullptr;
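For reference, the image_data loop added above expects images to arrive as base64 strings inside the request JSON. A minimal sketch of the payload shape it consumes, using the nlohmann json alias already declared in this file (the prompt text and base64 content are placeholders, not taken from the commit):

    json data;
    data["prompt"] = "describe this image";
    data["image_data"] = json::array();
    // each element's "data" field is base64-decoded into a raw_buffer and appended to `files`
    data["image_data"].push_back({ {"data", "<base64-encoded image bytes>"} });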
@@ -3960,8 +4004,7 @@
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(std::move(tasks));
         } catch (const std::exception & e) {
-            res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
-            return;
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
         }
         ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
@@ -3973,9 +4016,9 @@
             backend::Reply reply;
             reply.set_message(completion_text);
             int32_t tokens_predicted = res.value("tokens_predicted", 0);
-            reply->set_tokens(tokens_predicted);
+            reply.set_tokens(tokens_predicted);
             int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
-            reply->set_prompt_tokens(tokens_evaluated);
+            reply.set_prompt_tokens(tokens_evaluated);
             if (res.contains("timings")) {
                 double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
@@ -3985,9 +4028,6 @@
             }
             // Log Request Correlation Id
-            LOG_VERBOSE("correlation:", {
-                { "id", data["correlation_id"] }
-            });
             // Send the reply
             writer->Write(reply);
@@ -4009,10 +4049,7 @@
                 reply.set_timing_token_generation(timing_token_generation);
             }
-            // Log Request Correlation Id
-            LOG_VERBOSE("correlation:", {
-                { "id", data["correlation_id"] }
-            });
             // Send the reply
             writer->Write(reply);
@@ -4024,9 +4061,9 @@
             reply.set_message(error_data.value("content", ""));
             writer->Write(reply);
             return true;
-        }, [&writer]() {
-            // note: do not use req.is_connection_closed here because req is already destroyed
-            return !writer->IsWritePossible();
+        }, [&]() {
+            // NOTE: we should try to check when the writer is closed here
+            return false;
         });
         ctx_server.queue_results.remove_waiting_task_ids(task_ids);
@@ -4035,6 +4072,7 @@
     }
     grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
+        json data = parse_options(true, request);
         //Raise error if embeddings is set to true
         if (ctx_server.params_base.embedding) {
@@ -4047,9 +4085,22 @@
         std::vector<server_task> tasks;
         const auto & prompt = data.at("prompt");
+        const auto type = SERVER_TASK_TYPE_COMPLETION;
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
+        std::vector<raw_buffer> files;
+        const auto &images_data = data.find("image_data");
+        if (images_data != data.end() && images_data->is_array())
+        {
+            for (const auto &img : *images_data)
+            {
+                const std::vector<uint8_t> image_buffer = base64_decode(img["data"].get<std::string>());
+                raw_buffer data;
+                data.insert(data.end(), image_buffer.begin(), image_buffer.end());
+                files.push_back(data);
+            }
+        }
         // process files
         mtmd::bitmaps bitmaps;
         const bool has_mtmd = ctx_server.mctx != nullptr;
@@ -4131,8 +4182,7 @@
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(std::move(tasks));
         } catch (const std::exception & e) {
-            res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
-            return;
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
         }
         ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
@@ -4209,9 +4259,11 @@
                 responses.push_back(res->to_json());
             }
         }, [&](const json & error_data) {
-            res_error(res, error_data);
-            error = true;
-        }, req.is_connection_closed);
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", ""));
+        }, [&]() {
+            // NOTE: we should try to check when the writer is closed here
+            return false;
+        });
         ctx_server.queue_results.remove_waiting_task_ids(task_ids);
@@ -4290,20 +4342,6 @@
 };
-void RunServer(const std::string& server_address, server_context& ctx_server) {
-    BackendServiceImpl service(ctx_server);
-    ServerBuilder builder;
-    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
-    builder.RegisterService(&service);
-    builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
-    builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
-    builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
-    std::unique_ptr<Server> server(builder.BuildAndStart());
-    std::cout << "Server listening on " << server_address << std::endl;
-    server->Wait();
-}
 int main(int argc, char** argv) {
     std::string server_address("localhost:50051");
@@ -4328,14 +4366,31 @@ int main(int argc, char** argv) {
     }
     server_context ctx_server;
+    BackendServiceImpl service(ctx_server);
+    ServerBuilder builder;
+    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
+    builder.RegisterService(&service);
+    builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
+    builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
+    builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
+    std::unique_ptr<Server> server(builder.BuildAndStart());
     // run the HTTP server in a thread - see comment below
     std::thread t([&]()
     {
-        RunServer(server_address, ctx_server);
+        std::cout << "Server listening on " << server_address << std::endl;
+        server->Wait();
         return 0;
     });
+    // clean up function, to be called before exit
+    auto clean_up = [&server, &ctx_server]() {
+        SRV_INF("%s: cleaning up before exit...\n", __func__);
+        server->Shutdown();
+        ctx_server.queue_results.terminate();
+        llama_backend_free();
+    };
     //);
     start_llama_server(ctx_server);
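The clean_up lambda is only defined within this hunk; the diff ends before showing where it runs. One plausible continuation of main(), purely as a sketch (the ordering and the join are assumptions, not shown in this commit):

    // hypothetical tail of main(), after start_llama_server() returns
    // (i.e. once queue_tasks.terminate() has been called):
    clean_up();  // Shutdown() lets server->Wait() return in the gRPC thread
    t.join();    // join the thread that ran server->Wait()
    return 0;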