Mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-20 02:24:59 +00:00)

Commit: Make it compile
commit 6381f9bda2
parent 453eb7d1c8
1 changed file with 119 additions and 64 deletions
@@ -27,9 +27,16 @@
#include <thread>
#include <unordered_map>
#include <unordered_set>

// LocalAI

#include "backend.pb.h"
#include "backend.grpc.pb.h"
#include <getopt.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <regex>


using grpc::Server;
@@ -37,6 +44,7 @@ using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::Status;
using json = nlohmann::ordered_json;
// END LocalAI

constexpr int HTTP_POLLING_SECONDS = 1;

@@ -336,22 +344,23 @@ struct server_task {
            }
        }

        // process "json_schema" and "grammar"
        if (data.contains("json_schema") && !data.contains("grammar")) {
            try {
                auto schema = json_value(data, "json_schema", json::object());
                SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
                params.sampling.grammar = json_schema_to_grammar(schema);
                SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
            } catch (const std::exception & e) {
                throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
            }
        } else {
            params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
            SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
            params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
            SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
        }
        // TODO: add back json_schema and grammar support
        // // process "json_schema" and "grammar"
        // if (data.contains("json_schema") && !data.contains("grammar")) {
        //     try {
        //         auto schema = json_value(data, "json_schema", json::object());
        //         SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
        //         params.sampling.grammar = json_schema_to_grammar(schema);
        //         SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
        //     } catch (const std::exception & e) {
        //         throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
        //     }
        // } else {
        //     params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
        //     SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
        //     params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
        //     SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
        // }

        {
            auto it = data.find("chat_format");
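For context on the block the commit comments out above: when a request carries "json_schema" but no explicit "grammar", the schema is translated into a GBNF grammar that then constrains sampling. Below is a minimal standalone sketch of that conversion. It assumes llama.cpp's common helpers (json-schema-to-grammar.h and nlohmann's ordered_json) are available on the include path, and the schema shown is illustrative, not taken from the patch.

// Sketch only: mirrors the precedence in the (now commented-out) block above,
// using llama.cpp's json_schema_to_grammar() helper. Include paths assume the
// llama.cpp common library; the schema is made up.
#include <cstdio>
#include <string>
#include <nlohmann/json.hpp>
#include "json-schema-to-grammar.h"

using json = nlohmann::ordered_json;

int main() {
    json data = {
        {"json_schema", {
            {"type", "object"},
            {"properties", {{"name", {{"type", "string"}}}}},
            {"required", {"name"}}
        }}
    };

    // An explicit "grammar" field wins; "json_schema" is only used as a fallback.
    if (data.contains("json_schema") && !data.contains("grammar")) {
        std::string grammar = json_schema_to_grammar(data["json_schema"]);
        std::printf("%s\n", grammar.c_str());  // GBNF text fed to the sampler
    }
    return 0;
}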
@@ -3558,9 +3567,6 @@ inline void signal_handler(int signal) {

bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model

// The class has a llama instance that is shared across all RPCs
llama_server_context llama;

static void start_llama_server(server_context& ctx_server) {
    // Wait for model to be loaded first
    while (!loaded_model) {
@@ -3568,7 +3574,7 @@ static void start_llama_server(server_context& ctx_server) {
    }

    ctx_server.init();
    state.store(SERVER_STATE_READY);
    //state.store(SERVER_STATE_READY);

    LOG_INF("%s: model loaded\n", __func__);

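The gate above busy-waits on loaded_model, which LoadModel flips once llama finishes loading; as the TODO on its declaration notes, a std::atomic<bool> would be the usual way to make that flag safe across threads. A rough sketch of how the polling loop typically looks, using names from this file; the sleep interval is an assumption, since the body of the loop sits outside this hunk.

// Sketch, not the committed code: poll until LoadModel() sets the flag,
// then bring the server context up. HTTP_POLLING_SECONDS is defined earlier
// in the file; using it as the sleep interval here is an assumption.
// (Requires <thread> and <chrono>.)
while (!loaded_model) {
    std::this_thread::sleep_for(std::chrono::seconds(HTTP_POLLING_SECONDS));
}
ctx_server.init();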
@@ -3604,8 +3610,6 @@ static void start_llama_server(server_context& ctx_server) {
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);

    // this call blocks the main thread until queue_tasks.terminate() is called
    ctx_server.queue_tasks.start_loop();
}
@@ -3687,8 +3691,35 @@ static std::string get_all_kv_cache_types() {
    return msg.str();
}


// Adds an RPC server
// https://github.com/ggerganov/llama.cpp/compare/4dbc8b9cb71876e005724f4e8f73a3544646bcf5..3edfa7d3753c29e44b964c0ff424d2ea8d5fdee6
static void add_rpc_devices(std::string servers) {
    auto rpc_servers = string_split<std::string>(servers, ',');
    if (rpc_servers.empty()) {
        throw std::invalid_argument("no RPC servers specified");
    }
    ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
    if (!rpc_reg) {
        throw std::invalid_argument("failed to find RPC backend");
    }
    typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
    ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
    if (!ggml_backend_rpc_add_device_fn) {
        throw std::invalid_argument("failed to find RPC device add function");
    }
    for (const auto & server : rpc_servers) {
        ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
        if (dev) {
            ggml_backend_device_register(dev);
        } else {
            throw std::invalid_argument("failed to register RPC device");
        }
    }
}

static void params_parse(const backend::ModelOptions* request,
                         common_params & params, llama_server_context &llama) {
                         common_params & params) {

    // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809

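add_rpc_devices() above resolves the RPC backend's device-add entry point at runtime through ggml_backend_reg_get_proc_address, so the binary does not need to link the RPC backend directly; each registered device then looks to llama.cpp like any other backend device. A hedged usage sketch follows; the endpoints are made up, and how LocalAI actually passes the list (for example via a ModelOptions field) is not shown in this hunk.

// Hypothetical call site: each comma-separated endpoint becomes one ggml
// RPC device. Endpoints here are illustrative only.
try {
    add_rpc_devices("192.168.1.10:50052,192.168.1.11:50052");
} catch (const std::invalid_argument & e) {
    fprintf(stderr, "RPC setup failed: %s\n", e.what());  // needs <cstdio>, <stdexcept>
}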
@@ -3736,7 +3767,7 @@ static void params_parse(const backend::ModelOptions* request,
        }

        if (!strcmp(optname, "gpu")) {
            llama.has_gpu = true;
            // llama.has_gpu = true;
        }
    }

@@ -3805,7 +3836,6 @@ static void params_parse(const backend::ModelOptions* request,
    }

    if (request->grammartriggers_size() > 0) {
        LOG_INFO("configuring grammar triggers", {});
        params.sampling.grammar_lazy = true;
        for (int i = 0; i < request->grammartriggers_size(); i++) {
            common_grammar_trigger trigger;
@@ -3813,9 +3843,7 @@ static void params_parse(const backend::ModelOptions* request,
            trigger.value = request->grammartriggers(i).word();
            // trigger.at_start = request->grammartriggers(i).at_start();
            params.sampling.grammar_triggers.push_back(trigger);
            LOG_INFO("grammar trigger", {
                { "word", trigger.value },
            });

        }
    }
}
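For reference, a hedged sketch of what one proto grammar-trigger entry turns into: a lazy grammar that only starts constraining output once the trigger word appears. The trigger type assignment is an assumption (the line that sets it sits just outside this hunk), and the trigger word itself is made up.

// Sketch only. COMMON_GRAMMAR_TRIGGER_TYPE_WORD is llama.cpp's word-style
// trigger type; setting it here is an assumption about code outside the hunk.
// "<tool_call>" is a hypothetical trigger word.
common_grammar_trigger trigger;
trigger.type  = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
trigger.value = "<tool_call>";
params.sampling.grammar_lazy = true;                 // grammar kicks in only after the trigger
params.sampling.grammar_triggers.push_back(trigger);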
@@ -3857,6 +3885,8 @@ public:
        result->set_message("Loading succeeded");
        result->set_success(true);
        loaded_model = true;
        ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;

        return Status::OK;
    }

@@ -3876,9 +3906,23 @@ public:
        std::vector<server_task> tasks;

        const auto & prompt = data.at("prompt");
        const auto type = SERVER_TASK_TYPE_COMPLETION;
        // TODO: this log can become very long, put it behind a flag or think about a more compact format
        //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

        std::vector<raw_buffer> files;
        const auto &images_data = data.find("image_data");
        if (images_data != data.end() && images_data->is_array())
        {
            for (const auto &img : *images_data)
            {
                const std::vector<uint8_t> image_buffer = base64_decode(img["data"].get<std::string>());
                raw_buffer data;
                data.insert(data.end(), image_buffer.begin(), image_buffer.end());
                files.push_back(data);
            }
        }

        // process files
        mtmd::bitmaps bitmaps;
        const bool has_mtmd = ctx_server.mctx != nullptr;
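The loop above only reads the "data" field of each image_data entry, base64-decodes it, and appends the raw bytes as one file for the multimodal (mtmd) path. A hedged sketch of the request shape it expects; the base64 payload is a stand-in, not a real image.

// Illustrative request fragment (field names taken from the loop above; the
// base64 string is a made-up placeholder, not real image bytes).
json data;
data["prompt"]     = "Describe this picture.";
data["image_data"] = json::array({
    { { "data", "aGVsbG8gd29ybGQ=" } }   // base64 of the raw image file
});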
@@ -3960,8 +4004,7 @@ public:
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(std::move(tasks));
        } catch (const std::exception & e) {
            res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
            return;
            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
        }

        ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
@@ -3973,9 +4016,9 @@ public:
            backend::Reply reply;
            reply.set_message(completion_text);
            int32_t tokens_predicted = res.value("tokens_predicted", 0);
            reply->set_tokens(tokens_predicted);
            reply.set_tokens(tokens_predicted);
            int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
            reply->set_prompt_tokens(tokens_evaluated);
            reply.set_prompt_tokens(tokens_evaluated);

            if (res.contains("timings")) {
                double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
@@ -3985,10 +4028,7 @@ public:
            }

            // Log Request Correlation Id
            LOG_VERBOSE("correlation:", {
                { "id", data["correlation_id"] }
            });


            // Send the reply
            writer->Write(reply);
        }
@@ -4009,10 +4049,7 @@ public:
                reply.set_timing_token_generation(timing_token_generation);
            }

            // Log Request Correlation Id
            LOG_VERBOSE("correlation:", {
                { "id", data["correlation_id"] }
            });


            // Send the reply
            writer->Write(reply);
@@ -4024,9 +4061,9 @@ public:
            reply.set_message(error_data.value("content", ""));
            writer->Write(reply);
            return true;
        }, [&writer]() {
            // note: do not use req.is_connection_closed here because req is already destroyed
            return !writer->IsWritePossible();
        }, [&]() {
            // NOTE: we should try to check when the writer is closed here
            return false;
        });

        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
@@ -4035,7 +4072,8 @@ public:
    }

    grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {

        json data = parse_options(true, request);

        //Raise error if embeddings is set to true
        if (ctx_server.params_base.embedding) {
            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode");
@@ -4047,9 +4085,22 @@ public:
        std::vector<server_task> tasks;

        const auto & prompt = data.at("prompt");
        const auto type = SERVER_TASK_TYPE_COMPLETION;
        // TODO: this log can become very long, put it behind a flag or think about a more compact format
        //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

        std::vector<raw_buffer> files;
        const auto &images_data = data.find("image_data");
        if (images_data != data.end() && images_data->is_array())
        {
            for (const auto &img : *images_data)
            {
                const std::vector<uint8_t> image_buffer = base64_decode(img["data"].get<std::string>());
                raw_buffer data;
                data.insert(data.end(), image_buffer.begin(), image_buffer.end());
                files.push_back(data);
            }
        }
        // process files
        mtmd::bitmaps bitmaps;
        const bool has_mtmd = ctx_server.mctx != nullptr;
@@ -4131,8 +4182,7 @@ public:
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(std::move(tasks));
        } catch (const std::exception & e) {
            res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
            return;
            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
        }

        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
@@ -4209,9 +4259,11 @@ public:
                responses.push_back(res->to_json());
            }
        }, [&](const json & error_data) {
            res_error(res, error_data);
            error = true;
        }, req.is_connection_closed);
            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", ""));
        }, [&]() {
            // NOTE: we should try to check when the writer is closed here
            return false;
        });

        ctx_server.queue_results.remove_waiting_task_ids(task_ids);

@@ -4290,20 +4342,6 @@ public:
};


void RunServer(const std::string& server_address, server_context& ctx_server) {
    BackendServiceImpl service(ctx_server);

    ServerBuilder builder;
    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
    builder.RegisterService(&service);
    builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
    builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
    builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
    std::unique_ptr<Server> server(builder.BuildAndStart());
    std::cout << "Server listening on " << server_address << std::endl;
    server->Wait();
}

int main(int argc, char** argv) {
    std::string server_address("localhost:50051");

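The builder configuration above (kept essentially intact when it moves into main() in the next hunk) raises gRPC's message limits well beyond the default receive cap of about 4 MB, so that requests carrying base64-encoded images and long replies fit in a single message. A short annotated sketch of the relevant calls; the 4 MB default and the framing here are my reading, not stated in the patch.

// Annotated version of the calls above. Without these, a single large
// base64-encoded image in a PredictOptions request could exceed gRPC's
// default receive limit; 50 MB is this backend's chosen ceiling.
builder.SetMaxSendMessageSize(50 * 1024 * 1024);     // outgoing replies
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024);  // incoming requests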
@@ -4328,14 +4366,31 @@ int main(int argc, char** argv) {
    }

    server_context ctx_server;
    BackendServiceImpl service(ctx_server);

    ServerBuilder builder;
    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
    builder.RegisterService(&service);
    builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB
    builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB
    builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB
    std::unique_ptr<Server> server(builder.BuildAndStart());
    // run the HTTP server in a thread - see comment below
    std::thread t([&]()
    {
        RunServer(server_address, ctx_server);
        std::cout << "Server listening on " << server_address << std::endl;
        server->Wait();
        return 0;
    });

    // clean up function, to be called before exit
    auto clean_up = [&server, &ctx_server]() {
        SRV_INF("%s: cleaning up before exit...\n", __func__);
        server->Shutdown();
        ctx_server.queue_results.terminate();
        llama_backend_free();
    };


    //);
    start_llama_server(ctx_server);
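After this commit, main() owns the gRPC server directly: Server::Wait() runs inside a std::thread while start_llama_server() blocks the main thread in the llama task loop, and clean_up() is meant to unwind both. A rough sketch of that shape, using names from the patch; the join and the placement of clean_up() are assumptions about code outside this hunk.

// Sketch of the startup/shutdown flow implied above (not the literal code).
std::thread grpc_thread([&]() {
    std::cout << "Server listening on " << server_address << std::endl;
    server->Wait();                  // serve RPCs until Shutdown() is called
});

start_llama_server(ctx_server);      // blocks in ctx_server.queue_tasks.start_loop()

clean_up();                          // Shutdown() the gRPC server, terminate the result queue
grpc_thread.join();                  // assumption: the thread is joined before exit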