Make it functional

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
commit f0e265a96d (parent 8b3c083c97)
Author: Ettore Di Giacinto <mudler@localai.io>
Date:   2023-11-08 18:26:20 +01:00

@@ -1811,7 +1811,24 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
 ////////////////////////////////
 //////// LOCALAI
+bool loaded_model; // TODO: add a mutex for this; it is written only once, when the model loads
+// The class has a llama instance that is shared across all RPCs
+llama_server_context llama;
+static void start_llama_server() {
+    // Wait for the model to be loaded first
+    while (!loaded_model) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    }
+    bool running = true;
+    while (running)
+    {
+        running = llama.update_slots();
+        std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    }
+}
 json parse_options(bool streaming, const backend::PredictOptions* predict, llama_server_context &llama)
 {
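
The new start_llama_server() polls the plain bool loaded_model that the LoadModel RPC sets from another thread, which is exactly the race the TODO above points at. A minimal sketch of one way to close that gap with std::atomic; the name loaded_model_flag, the helper wait_for_model and the memory orderings are illustrative, not part of this commit:

#include <atomic>
#include <chrono>
#include <thread>

// Illustrative only: an atomic flag gives the load/serve handshake a
// well-defined happens-before edge without changing the polling pattern.
static std::atomic<bool> loaded_model_flag{false};

static void wait_for_model() {
    while (!loaded_model_flag.load(std::memory_order_acquire)) {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
}

// LoadModel would then publish the flag with:
//     loaded_model_flag.store(true, std::memory_order_release);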
@@ -1951,7 +1968,15 @@ static void params_parse(const backend::ModelOptions* request,
     params.n_threads = request->threads();
     params.n_gpu_layers = request->ngpulayers();
     params.n_batch = request->nbatch();
+    // Set params.n_parallel from the LLAMACPP_PARALLEL environment variable; defaults to 1
+    //params.n_parallel = 1;
+    const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
+    if (env_parallel != NULL) {
+        params.n_parallel = std::stoi(env_parallel);
+    } else {
+        params.n_parallel = 1;
+    }
     // TODO: Add yarn
     if (!request->tensorsplit().empty()) {
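
std::stoi throws std::invalid_argument or std::out_of_range when LLAMACPP_PARALLEL holds something that is not a valid integer, which would abort the backend during startup. A defensive variant, sketched here with a hypothetical env_to_int helper that is not part of the commit:

#include <cstdlib>
#include <exception>
#include <string>

// Hypothetical helper: fall back to a default when the variable is unset
// or cannot be parsed, instead of letting std::stoi throw.
static int env_to_int(const char *name, int fallback) {
    const char *value = std::getenv(name);
    if (value == nullptr) {
        return fallback;
    }
    try {
        return std::stoi(value);
    } catch (const std::exception &) {
        return fallback;
    }
}

// Usage inside params_parse():
//     params.n_parallel = env_to_int("LLAMACPP_PARALLEL", 1);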
@@ -1985,8 +2010,6 @@ static void params_parse(const backend::ModelOptions* request,
     params.embedding = request->embeddings();
 }
-// The class has a llama instance that is shared across all RPCs
-llama_server_context llama;
 // GRPC Server start
 class BackendServiceImpl final : public backend::Backend::Service {
@@ -2014,6 +2037,7 @@ public:
         llama.initialize();
         result->set_message("Loading succeeded");
         result->set_success(true);
+        loaded_model = true;
         return Status::OK;
     }
     grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
@@ -2032,7 +2056,10 @@ public:
                 });
                 backend::Reply reply;
-                reply.set_message(str.c_str());
+                // reply with just the generated text
+                std::string completion_text = result.result_json.value("content", "");
+                reply.set_message(completion_text);
                 // Send the reply
                 writer->Write(reply);
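
PredictStream now sends only the generated text instead of the raw "data: {...}" SSE payload it forwarded before. The extraction relies on nlohmann::json's value(), which returns the supplied default when the key is missing, so a partial result without "content" yields an empty message rather than an exception. A standalone illustration with invented values:

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Invented example of a token result; only the lookup pattern matters.
    json result = {{"content", " world"}, {"stop", false}};

    std::string completion_text = result.value("content", "");
    std::cout << completion_text << "\n";                    // prints " world"
    std::cout << result.value("missing", "<empty>") << "\n"; // prints "<empty>"
    return 0;
}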
@@ -2060,12 +2087,13 @@ public:
     grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
-        json data = parse_options(true, request, llama);
+        json data = parse_options(false, request, llama);
         const int task_id = llama.request_completion(data, false, false);
+        std::string completion_text;
         task_result result = llama.next_result(task_id);
         if (!result.error && result.stop) {
-            reply->set_message(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace));
+            completion_text = result.result_json.value("content", "");
+            reply->set_message(completion_text);
         }
         else
         {
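
Predict now requests a non-streaming completion from parse_options and reads a single task_result, using its content only when the result is final (result.stop) and error-free. If the server were ever to deliver intermediate results on this path, a slightly more defensive variant would keep draining next_result() until stop is set; a sketch under that assumption, reusing the request_completion / next_result / task_result types already present in this file:

// Sketch only: accumulate "content" until the final result arrives.
static std::string run_blocking_completion(llama_server_context &llama, const json &data) {
    const int task_id = llama.request_completion(data, false, false);
    std::string completion_text;
    while (true) {
        task_result result = llama.next_result(task_id);
        if (result.error) {
            break;
        }
        completion_text += result.result_json.value("content", "");
        if (result.stop) {
            break;
        }
    }
    return completion_text;
}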
@@ -2118,17 +2146,10 @@ int main(int argc, char** argv) {
         return 0;
     });
-    {
-        bool running = true;
-        while (running)
-        {
-            running = llama.update_slots();
-            std::this_thread::sleep_for(std::chrono::milliseconds(1));
-            // print state
-            std::cout << running << std::endl;
-        }
-    }
     //);
+    start_llama_server();
     std::cout << "stopping" << std::endl;
     t.join();
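
With this change main() keeps the gRPC service on the worker thread t and turns the main thread into the generation driver: start_llama_server() blocks until LoadModel flips loaded_model, then pumps llama.update_slots() until it returns false, after which control falls through to t.join(). A stripped-down sketch of that arrangement; RunServer and the empty bodies here merely stand in for the real startup code in this file:

#include <iostream>
#include <thread>

// Stand-ins for the real symbols; bodies elided, only the structure matters.
static void RunServer() { /* grpc::ServerBuilder ... server->Wait(); */ }
static void start_llama_server_sketch() { /* wait for model, pump update_slots() */ }

int main() {
    std::thread t([]() {
        RunServer();                 // serves LoadModel / Predict / PredictStream
    });

    start_llama_server_sketch();     // main thread: token generation loop
    std::cout << "stopping" << std::endl;
    t.join();
    return 0;
}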