mirror of https://github.com/mudler/LocalAI.git
Make it functional

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

parent 8b3c083c97
commit f0e265a96d

1 changed file with 38 additions and 17 deletions

@@ -1811,7 +1811,24 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
 ////////////////////////////////
 //////// LOCALAI

+bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
+
+// The class has a llama instance that is shared across all RPCs
+llama_server_context llama;
+
+static void start_llama_server() {
+    // Wait for model to be loaded first
+    while (!loaded_model) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    }
+
+    bool running = true;
+    while (running)
+    {
+        running = llama.update_slots();
+        std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    }
+}
+
 json parse_options(bool streaming, const backend::PredictOptions* predict, llama_server_context &llama)
 {
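
A note on the TODO above: loaded_model is written by the LoadModel RPC thread and polled by the server loop, which is a data race on a plain bool. A minimal sketch of the lock-free fix the TODO hints at, using std::atomic instead of a mutex (the flag name and 100 ms polling cadence are taken from the diff; wait_for_model is an illustrative helper, not part of the patch):

    #include <atomic>
    #include <chrono>
    #include <thread>

    // An atomic flag gives cross-thread visibility without a mutex;
    // the loader thread publishes, the server loop observes.
    std::atomic<bool> loaded_model{false};

    static void wait_for_model() {
        // Same 100 ms polling cadence as start_llama_server() above.
        while (!loaded_model.load(std::memory_order_acquire)) {
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }
    }

    // The LoadModel handler would publish with:
    //     loaded_model.store(true, std::memory_order_release);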
@@ -1951,7 +1968,15 @@ static void params_parse(const backend::ModelOptions* request,
     params.n_threads = request->threads();
     params.n_gpu_layers = request->ngpulayers();
     params.n_batch = request->nbatch();
-    params.n_parallel = 1;
+    // Set params.n_parallel by environment variable (LLAMACPP_PARALLEL), defaults to 1
+    //params.n_parallel = 1;
+    const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
+    if (env_parallel != NULL) {
+        params.n_parallel = std::stoi(env_parallel);
+    } else {
+        params.n_parallel = 1;
+    }

     // TODO: Add yarn

     if (!request->tensorsplit().empty()) {
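
One caveat with the lookup above: std::stoi throws std::invalid_argument or std::out_of_range when LLAMACPP_PARALLEL holds a non-numeric or oversized value, which would terminate the backend at startup. A defensive variant with the same semantics (read_parallel_from_env is an illustrative helper, not part of the patch):

    #include <cstdlib>
    #include <exception>
    #include <string>

    // Parse LLAMACPP_PARALLEL, falling back to 1 on missing,
    // non-numeric, or non-positive values instead of throwing.
    static int read_parallel_from_env() {
        const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
        if (env_parallel == nullptr) {
            return 1;
        }
        try {
            const int n = std::stoi(env_parallel);
            return n > 0 ? n : 1;
        } catch (const std::exception &) {
            return 1;
        }
    }

With either version, starting the backend with LLAMACPP_PARALLEL=4 in the environment would request four parallel slots.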
@@ -1985,8 +2010,6 @@ static void params_parse(const backend::ModelOptions* request,
     params.embedding = request->embeddings();
 }

-// The class has a llama instance that is shared across all RPCs
-llama_server_context llama;

 // GRPC Server start
 class BackendServiceImpl final : public backend::Backend::Service {
@@ -2014,6 +2037,7 @@ public:
         llama.initialize();
         result->set_message("Loading succeeded");
         result->set_success(true);
+        loaded_model = true;
         return Status::OK;
     }
     grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
@@ -2031,8 +2055,11 @@ public:
                 { "to_send", str }
             });

             backend::Reply reply;
-            reply.set_message(str.c_str());
+            // print it
+            std::string completion_text = result.result_json.value("content", "");
+
+            reply.set_message(completion_text);

             // Send the reply
             writer->Write(reply);
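
For context on the streaming path: PredictStream now forwards only the "content" field of each partial result, one backend::Reply per chunk. A client-side sketch of consuming that stream, assuming the stub and accessors that standard gRPC C++ codegen would generate from this service (backend::Backend::NewStub, reply.message(), and the generated header name are inferred, not shown in this diff):

    #include <grpcpp/grpcpp.h>
    #include <iostream>
    #include <memory>
    #include "backend.grpc.pb.h"  // assumed generated header name

    // Read one Reply per generated chunk until the server closes the stream.
    void stream_predict(const std::shared_ptr<grpc::Channel> &channel,
                        const backend::PredictOptions &opts) {
        auto stub = backend::Backend::NewStub(channel);
        grpc::ClientContext ctx;
        auto reader = stub->PredictStream(&ctx, opts);

        backend::Reply reply;
        while (reader->Read(&reply)) {
            std::cout << reply.message();  // partial "content" text
        }
        grpc::Status status = reader->Finish();
        if (!status.ok()) {
            std::cerr << "stream failed: " << status.error_message() << "\n";
        }
    }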
@@ -2060,12 +2087,13 @@ public:


    grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
-        json data = parse_options(true, request, llama);
+        json data = parse_options(false, request, llama);
         const int task_id = llama.request_completion(data, false, false);
         std::string completion_text;
         task_result result = llama.next_result(task_id);
         if (!result.error && result.stop) {
-            reply->set_message(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace));
+            completion_text = result.result_json.value("content", "");
+            reply->set_message(completion_text);
         }
         else
         {
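
Both handlers now pull the text out with nlohmann::json's value() accessor, which returns the supplied default when the key is absent, so a result without a "content" field produces an empty reply rather than an exception. A self-contained illustration of that behavior:

    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        json with_content    = {{"content", "hello"}};
        json without_content = {{"stop", true}};

        // value() falls back to the default; at() would throw
        // json::out_of_range for the missing key.
        std::cout << with_content.value("content", "") << "\n";     // hello
        std::cout << without_content.value("content", "") << "\n";  // (empty)
    }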
@@ -2118,17 +2146,10 @@ int main(int argc, char** argv) {
         return 0;
     });

-    {
-    bool running = true;
-    while (running)
-    {
-        running = llama.update_slots();
-        std::this_thread::sleep_for(std::chrono::milliseconds(1));
-        // print state
-        std::cout << running << std::endl;
-    }
-    }
     //);
+    start_llama_server();
+    std::cout << "stopping" << std::endl;

     t.join();
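
The net effect of this last hunk: the inlined update loop and its debug printing are gone, and main() now runs start_llama_server() on the main thread while the gRPC server keeps serving from thread t. A self-contained sketch of that threading shape, with the server simulated by a timed stand-in (serve and slot_loop are illustrative names, not from the patch):

    #include <atomic>
    #include <chrono>
    #include <iostream>
    #include <thread>

    std::atomic<bool> stopping{false};

    static void serve() {  // stands in for the gRPC server thread
        std::this_thread::sleep_for(std::chrono::seconds(1));
        stopping = true;   // server shut down
    }

    static void slot_loop() {  // stands in for start_llama_server()
        while (!stopping) {
            // llama.update_slots() would run here
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
    }

    int main() {
        std::thread t(serve);
        slot_loop();  // main thread pumps slots until the server stops
        std::cout << "stopping" << std::endl;
        t.join();
    }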