diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index ce852479..977e2fda 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -1811,7 +1811,24 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
 ////////////////////////////////
 //////// LOCALAI
 
+bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
+// The class has a llama instance that is shared across all RPCs
+llama_server_context llama;
+
+static void start_llama_server() {
+    // Wait for model to be loaded first
+    while (!loaded_model) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    }
+
+    bool running = true;
+    while (running)
+    {
+        running = llama.update_slots();
+        std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    }
+}
 
 
 json parse_options(bool streaming, const backend::PredictOptions* predict, llama_server_context &llama)
 {
@@ -1951,7 +1968,15 @@ static void params_parse(const backend::ModelOptions* request,
     params.n_threads = request->threads();
     params.n_gpu_layers = request->ngpulayers();
     params.n_batch = request->nbatch();
-    params.n_parallel = 1;
+    // Set params.n_parallel by environment variable (LLAMACPP_PARALLEL), defaults to 1
+    //params.n_parallel = 1;
+    const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
+    if (env_parallel != NULL) {
+        params.n_parallel = std::stoi(env_parallel);
+    } else {
+        params.n_parallel = 1;
+    }
+
     // TODO: Add yarn
 
     if (!request->tensorsplit().empty()) {
@@ -1985,8 +2010,6 @@ static void params_parse(const backend::ModelOptions* request,
     params.embedding = request->embeddings();
 }
 
-// The class has a llama instance that is shared across all RPCs
-llama_server_context llama;
 
 // GRPC Server start
 class BackendServiceImpl final : public backend::Backend::Service {
@@ -2014,6 +2037,7 @@ public:
         llama.initialize();
         result->set_message("Loading succeeded");
         result->set_success(true);
+        loaded_model = true;
         return Status::OK;
     }
     grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
@@ -2031,8 +2055,11 @@
                     { "to_send", str }
                 });
 
-                backend::Reply reply;
-                reply.set_message(str.c_str());
+                backend::Reply reply;
+                // send only the generated text ("content"), not the raw result JSON
+                std::string completion_text = result.result_json.value("content", "");
+
+                reply.set_message(completion_text);
 
                 // Send the reply
                 writer->Write(reply);
@@ -2060,12 +2087,13 @@
 
 
     grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
-        json data = parse_options(true, request, llama);
+        json data = parse_options(false, request, llama);
         const int task_id = llama.request_completion(data, false, false);
         std::string completion_text;
         task_result result = llama.next_result(task_id);
         if (!result.error && result.stop) {
-            reply->set_message(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace));
+            completion_text = result.result_json.value("content", "");
+            reply->set_message(completion_text);
         }
         else
         {
@@ -2118,17 +2146,10 @@ int main(int argc, char** argv) {
     return 0;
   });
 
-  {
-      bool running = true;
-      while (running)
-      {
-          running = llama.update_slots();
-          std::this_thread::sleep_for(std::chrono::milliseconds(1));
-          // print state
-          std::cout << running << std::endl;
-      }
-  }
+  //);
+  start_llama_server();
+  std::cout << "stopping" << std::endl;
 
   t.join();
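
Note (not part of the patch): a minimal standalone sketch of the LLAMACPP_PARALLEL handling added to params_parse above. The helper name env_to_int and the try/catch fallback for non-numeric values are illustrative assumptions, not code from this diff; the patch itself calls std::stoi directly and defaults to 1 only when the variable is unset.

    // build: g++ -std=c++11 env_parallel_sketch.cpp
    #include <cstdlib>
    #include <exception>
    #include <iostream>
    #include <string>

    // Read an integer from an environment variable, falling back to a default
    // when the variable is unset or not a valid number (hypothetical helper).
    static int env_to_int(const char *name, int fallback) {
        const char *value = std::getenv(name);
        if (value == NULL) {
            return fallback;
        }
        try {
            return std::stoi(value);
        } catch (const std::exception &) {
            return fallback;
        }
    }

    int main() {
        // Mirrors params.n_parallel in params_parse: LLAMACPP_PARALLEL, default 1.
        int n_parallel = env_to_int("LLAMACPP_PARALLEL", 1);
        std::cout << "n_parallel = " << n_parallel << std::endl;
        return 0;
    }

For example, starting the backend with LLAMACPP_PARALLEL=4 makes params_parse request four parallel slots from the shared llama_server_context instead of the previous hard-coded single slot.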