Ettore Di Giacinto 2023-11-07 19:12:58 +01:00
parent cf0d23828e
commit 8b3c083c97


@@ -21,6 +21,10 @@
#include "backend.grpc.pb.h"
// standard library includes (threading, timing, regex)
#include <cstddef>
#include <thread>
#include <mutex>
#include <chrono>
#include <regex>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
@@ -1779,22 +1783,6 @@ static json format_detokenized_response(std::string content)
}
static void log_server_request(const httplib::Request &req, const httplib::Response &res)
{
LOG_INFO("request", {
{"remote_addr", req.remote_addr},
{"remote_port", req.remote_port},
{"status", res.status},
{"method", req.method},
{"path", req.path},
{"params", req.params},
});
LOG_VERBOSE("request", {
{"request", req.body},
{"response", res.body},
});
}
struct token_translator
{
@@ -1823,72 +1811,131 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
////////////////////////////////
//////// LOCALAI
static void parse_options_completion(bool streaming,const backend::PredictOptions* predict, llama_server_context &llama)
json parse_options(bool streaming, const backend::PredictOptions* predict, llama_server_context &llama)
{
// https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L673
gpt_params default_params;
// For reference, this is how a slot reads its parameters from the json payload in the upstream server:
// slot->params.stream = json_value(data, "stream", false);
// slot->params.cache_prompt = json_value(data, "cache_prompt", false);
// slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
// slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
// slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
// slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
// slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
// slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
// slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
// slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
// slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
// slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
// slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
// slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
// slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
// slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
// slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
// slot->params.seed = json_value(data, "seed", default_params.seed);
// slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
// slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
llama.stream = streaming;
llama.params.n_predict = predict->tokens() == 0 ? -1 : predict->tokens();
llama.params.sparams.top_k = predict->topk();
llama.params.sparams.top_p = predict->topp();
llama.params.sparams.tfs_z = predict->tailfreesamplingz();
llama.params.sparams.typical_p = predict->typicalp();
llama.params.sparams.penalty_last_n = predict->repeat();
llama.params.sparams.temp = predict->temperature();
llama.params.sparams.penalty_repeat = predict->penalty();
llama.params.sparams.penalty_present = predict->presencepenalty();
llama.params.sparams.penalty_freq = predict->frequencypenalty();
llama.params.sparams.mirostat = predict->mirostat();
llama.params.sparams.mirostat_tau = predict->mirostattau();
llama.params.sparams.mirostat_eta = predict->mirostateta();
llama.params.sparams.penalize_nl = predict->penalizenl();
llama.params.n_keep = predict->nkeep();
llama.params.seed = predict->seed();
llama.params.sparams.grammar = predict->grammar();
// llama.params.n_probs = predict->
llama.params.prompt = predict->prompt();
// Build a json payload from the prediction options instead of filling llama.params directly
json data;
data["stream"] = streaming;
data["cache_prompt"] = predict->promptcacheall();
data["n_predict"] = predict->tokens() == 0 ? -1 : predict->tokens();
data["top_k"] = predict->topk();
data["top_p"] = predict->topp();
data["tfs_z"] = predict->tailfreesamplingz();
data["typical_p"] = predict->typicalp();
data["temperature"] = predict->temperature();
data["repeat_last_n"] = predict->repeat();
data["repeat_penalty"] = predict->penalty();
data["frequency_penalty"] = predict->frequencypenalty();
data["presence_penalty"] = predict->presencepenalty();
data["mirostat"] = predict->mirostat();
data["mirostat_tau"] = predict->mirostattau();
data["mirostat_eta"] = predict->mirostateta();
data["penalize_nl"] = predict->penalizenl();
data["n_keep"] = predict->nkeep();
data["seed"] = predict->seed();
data["grammar"] = predict->grammar();
data["prompt"] = predict->prompt();
data["ignore_eos"] = predict->ignoreeos();
llama.params.sparams.logit_bias.clear();
data["stop"] = predict->stopprompts();
// data["n_probs"] = predict->nprobs();
// TODO: images
if (predict->ignoreeos())
{
llama.params.sparams.logit_bias[llama_token_eos(llama.model)] = -INFINITY;
}
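// ignore_eos is honoured by pinning the EOS token's logit bias to -INFINITY,
// so sampling can never pick the end-of-sequence token.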
// const auto &logit_bias = body.find("logit_bias");
// if (logit_bias != body.end() && logit_bias->is_array())
// {
// const int n_vocab = llama_n_vocab(llama.model);
// for (const auto &el : *logit_bias)
// {
// if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
// {
// llama_token tok = el[0].get<llama_token>();
// if (tok >= 0 && tok < n_vocab)
// {
// if (el[1].is_number())
// {
// llama.params.logit_bias[tok] = el[1].get<float>();
// }
// else if (el[1].is_boolean() && !el[1].get<bool>())
// {
// llama.params.logit_bias[tok] = -INFINITY;
// }
// }
// }
// }
// }
llama.params.antiprompt.clear();
for (const std::string& stopPrompt : predict->stopprompts()) {
if (!stopPrompt.empty())
{
llama.params.antiprompt.push_back(stopPrompt);
}
}
return data;
}
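// The json built above mirrors the completion request payload of the upstream
// llama.cpp server example; the RPC handlers below hand it to llama.request_completion().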
// static void parse_options_completion(bool streaming,const backend::PredictOptions* predict, llama_server_context &llama)
// {
// // https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L673
// gpt_params default_params;
// llama.stream = streaming;
// llama.params.n_predict = predict->tokens() == 0 ? -1 : predict->tokens();
// llama.params.sparams.top_k = predict->topk();
// llama.params.sparams.top_p = predict->topp();
// llama.params.sparams.tfs_z = predict->tailfreesamplingz();
// llama.params.sparams.typical_p = predict->typicalp();
// llama.params.sparams.penalty_last_n = predict->repeat();
// llama.params.sparams.temp = predict->temperature();
// llama.params.sparams.penalty_repeat = predict->penalty();
// llama.params.sparams.penalty_present = predict->presencepenalty();
// llama.params.sparams.penalty_freq = predict->frequencypenalty();
// llama.params.sparams.mirostat = predict->mirostat();
// llama.params.sparams.mirostat_tau = predict->mirostattau();
// llama.params.sparams.mirostat_eta = predict->mirostateta();
// llama.params.sparams.penalize_nl = predict->penalizenl();
// llama.params.n_keep = predict->nkeep();
// llama.params.seed = predict->seed();
// llama.params.sparams.grammar = predict->grammar();
// // llama.params.n_probs = predict->
// llama.params.prompt = predict->prompt();
// llama.params.sparams.logit_bias.clear();
// if (predict->ignoreeos())
// {
// llama.params.sparams.logit_bias[llama_token_eos(llama.model)] = -INFINITY;
// }
// // const auto &logit_bias = body.find("logit_bias");
// // if (logit_bias != body.end() && logit_bias->is_array())
// // {
// // const int n_vocab = llama_n_vocab(llama.model);
// // for (const auto &el : *logit_bias)
// // {
// // if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
// // {
// // llama_token tok = el[0].get<llama_token>();
// // if (tok >= 0 && tok < n_vocab)
// // {
// // if (el[1].is_number())
// // {
// // llama.params.logit_bias[tok] = el[1].get<float>();
// // }
// // else if (el[1].is_boolean() && !el[1].get<bool>())
// // {
// // llama.params.logit_bias[tok] = -INFINITY;
// // }
// // }
// // }
// // }
// // }
// llama.params.antiprompt.clear();
// for (const std::string& stopPrompt : predict->stopprompts()) {
// if (!stopPrompt.empty())
// {
// llama.params.antiprompt.push_back(stopPrompt);
// }
// }
// }
static void params_parse(const backend::ModelOptions* request,
@@ -1904,6 +1951,7 @@ static void params_parse(const backend::ModelOptions* request,
params.n_threads = request->threads();
params.n_gpu_layers = request->ngpulayers();
params.n_batch = request->nbatch();
params.n_parallel = 1;
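// assumption: n_parallel maps to the number of server slots in llama_server_context;
// pinning it to 1 means this backend generates for a single sequence at a time.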
// TODO: Add yarn
if (!request->tensorsplit().empty()) {
@@ -1937,12 +1985,11 @@ static void params_parse(const backend::ModelOptions* request,
params.embedding = request->embeddings();
}
// A single llama_server_context instance is shared across all RPCs
llama_server_context llama;
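// Keeping it at file scope (instead of as a BackendServiceImpl member) lets the
// gRPC handlers and the slot update loop in main() share the same state; handlers
// serialize access through llama.lock().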
// GRPC Server start
class BackendServiceImpl final : public backend::Backend::Service {
// The class has a llama instance that is shared across all RPCs
llama_server_context llama;
public:
grpc::Status Health(ServerContext* context, const backend::HealthMessage* request, backend::Reply* reply) {
// Implement Health RPC
@@ -1970,126 +2017,61 @@ public:
return Status::OK;
}
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
// Stream partial completions to the client: each result is wrapped in a
// backend::Reply and pushed with writer->Write(); grpc::Status::OK is
// returned once generation stops.
auto lock = llama.lock();
json data = parse_options(true, request, llama);
const int task_id = llama.request_completion(data, false, false);
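// request_completion enqueues a task for the slot loop that main() drives via
// llama.update_slots(); next_result() is polled below until a result arrives
// with result.stop set.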
while (true)
{
task_result result = llama.next_result(task_id);
if (!result.error) {
const std::string str =
"data: " +
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
llama.rewind();
llama_reset_timings(llama.ctx);
parse_options_completion(false, request, llama);
llama.initSampling();
llama.loadPrompt(request->prompt());
llama.beginCompletion();
size_t sent_count = 0;
size_t sent_token_probs_index = 0;
while (llama.has_next_token) {
const completion_token_output token_with_probs = llama.doCompletion();
if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
continue;
}
const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
size_t pos = std::min(sent_count, llama.generated_text.size());
const std::string str_test = llama.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
is_stop_full = true;
llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size());
} else {
is_stop_full = false;
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
STOP_PARTIAL);
}
if (
stop_pos == std::string::npos ||
// Send rest of the text if we are at the end of the generation
(!llama.has_next_token && !is_stop_full && stop_pos > 0)
) {
const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
sent_count += to_send.size();
std::vector<completion_token_output> probs_output = {};
if (llama.params.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}
backend::Reply reply;
reply.set_message(to_send);
backend::Reply reply;
reply.set_message(str.c_str());
// Send the reply
writer->Write(reply);
if (result.stop) {
break;
}
} else {
break;
}
}
return grpc::Status::OK;
llama_print_timings(llama.ctx);
// auto on_complete = [task_id, &llama] (bool)
// {
// // cancel
// llama.request_cancel(task_id);
// };
llama.mutex.unlock();
lock.release();
return grpc::Status::OK;
}
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
auto lock = llama.lock();
llama.rewind();
llama_reset_timings(llama.ctx);
parse_options_completion(false, request, llama);
llama.initSampling();
llama.loadPrompt(request->prompt());
llama.beginCompletion();
if (llama.params.n_beams) {
// Fill llama.generated_token_probs vector with final beam.
llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
llama.n_past, llama.n_remain);
// Translate llama.generated_token_probs to llama.generated_text.
append_to_generated_text_from_generated_token_probs(llama);
} else {
size_t stop_pos = std::string::npos;
while (llama.has_next_token) {
const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok);
stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL);
}
if (stop_pos == std::string::npos) {
stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
}
if (stop_pos != std::string::npos) {
llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
llama.generated_text.end());
}
json data = parse_options(false, request, llama);
const int task_id = llama.request_completion(data, false, false);
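// Non-streaming path: a single next_result() call is expected to block until the
// final result (result.stop == true) is available; otherwise an empty reply is returned.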
std::string completion_text;
task_result result = llama.next_result(task_id);
if (!result.error && result.stop) {
reply->set_message(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace));
}
else
{
return grpc::Status::OK;
}
auto probs = llama.generated_token_probs;
if (llama.params.sparams.n_probs > 0 && llama.stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
}
reply->set_message(llama.generated_text);
return grpc::Status::OK;
}
};
@@ -2129,6 +2111,27 @@ int main(int argc, char** argv) {
}
}
RunServer(server_address);
// run the gRPC server in a background thread; the main thread drives the slot update loop below
std::thread t([&]()
{
RunServer(server_address);
return 0;
});
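// The main thread acts as the generation loop: update_slots() processes prompts and
// generates tokens for the queued tasks (following the upstream llama.cpp server
// design), while the gRPC handlers above only enqueue work and poll for results.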
{
bool running = true;
while (running)
{
running = llama.update_slots();
std::this_thread::sleep_for(std::chrono::milliseconds(1));
// print state
std::cout << running << std::endl;
}
}
//);
t.join();
llama_backend_free();
return 0;
}