diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp index c57f4070..aeb8d409 100644 --- a/backend/cpp/llama/grpc-server.cpp +++ b/backend/cpp/llama/grpc-server.cpp @@ -1,198 +1,1271 @@ -// llama.cpp gRPC C++ backend server -// -// Ettore Di Giacinto and llama.cpp authors -// -// This is a gRPC server for llama.cpp compatible with the LocalAI proto -// Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP (https://github.com/ggerganov/llama.cpp/tree/master/examples/server), -// but modified to work with gRPC -// - -#include -#include -#include -#include -#include "mtmd.h" -#include "log.h" -#include "stb_image.h" -#include "common.h" -#include "json.hpp" -#include "llama.h" -#include "backend.pb.h" -#include "backend.grpc.pb.h" #include "utils.hpp" + +#include "arg.h" +#include "common.h" +#include "json-schema-to-grammar.h" +#include "llama.h" +#include "log.h" #include "sampling.h" -// include std::regex -#include -#include -#include +#include "speculative.h" +#include "mtmd.h" + +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "json.hpp" +// mime type for sending response +#define MIMETYPE_JSON "application/json; charset=utf-8" + +#include #include -#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include #include -#include -#include + using grpc::Server; using grpc::ServerBuilder; using grpc::ServerContext; using grpc::Status; +using json = nlohmann::ordered_json; +constexpr int HTTP_POLLING_SECONDS = 1; -using backend::HealthMessage; - - -///// LLAMA.CPP server code below - -using json = nlohmann::json; - -struct server_params -{ - std::string hostname = "127.0.0.1"; - std::vector api_keys; - std::string public_path = "tools/server/public"; - std::string chat_template = ""; - int32_t port = 8080; - int32_t read_timeout = 600; - int32_t write_timeout = 600; - bool slots_endpoint = true; - bool metrics_endpoint = false; +enum stop_type { + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, }; -bool server_verbose = false; -bool server_log_json = true; +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, +}; -static size_t common_part(const std::vector &a, const std::vector &b) -{ - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) - { +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + +enum server_task_type { + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, + SERVER_TASK_TYPE_CANCEL, + SERVER_TASK_TYPE_NEXT_RESPONSE, + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_SET_LORA, +}; + +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, +}; + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error +}; + +struct slot_params { + bool stream = true; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters + + int64_t t_max_prompt_ms = -1; // TODO: implement + int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + + std::vector lora; + + std::vector antiprompt; + std::vector response_fields; + bool timings_per_token = false; + bool post_sampling_probs = false; + bool ignore_eos = false; + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + json to_json() const { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + auto grammar_triggers = json::array(); + for (const auto & trigger : sampling.grammar_triggers) { + server_grammar_trigger ct(std::move(trigger)); + grammar_triggers.push_back(ct.to_json()); + } + + return json { + {"n_predict", n_predict}, // Server configured n_predict + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, // User configured n_predict + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"grammar_lazy", sampling.grammar_lazy}, + {"grammar_triggers", grammar_triggers}, + {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_format)}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; } - return i; -} - -enum stop_type -{ - STOP_FULL, - STOP_PARTIAL, }; -static bool ends_with(const std::string &str, const std::string &suffix) -{ - return str.size() >= suffix.size() && - 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); -} +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) -static size_t find_partial_stop_string(const std::string &stop, - const std::string &text) -{ - if (!text.empty() && !stop.empty()) - { - const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) + server_task_type type; + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + slot_params params; + server_tokens prompt_tokens; + int id_selected_slot = -1; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task(server_task_type type) : type(type) {} + + static slot_params params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + slot_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) + slot_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = std::max(params.speculative.n_min, 0); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY { - if (stop[char_index] == text_last_char) - { - const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) - { - return text.size() - char_index - 1; + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); } } } - } - return std::string::npos; -} -// TODO: reuse llama_detokenize -template -static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) -{ - std::string ret; - for (; begin != end; ++begin) - { - ret += common_token_to_piece(ctx, *begin); - } - return ret; -} - -// format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) -{ - std::string out = token == -1 ? "" : common_token_to_piece(ctx, token); - // if the size is 1 and first bit is 1, meaning it's a partial character - // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) - { - std::stringstream ss; - ss << std::hex << (out[0] & 0xff); - std::string res(ss.str()); - out = "byte: \\x" + res; - } - return out; -} - -// Adds an RPC server -// https://github.com/ggerganov/llama.cpp/compare/4dbc8b9cb71876e005724f4e8f73a3544646bcf5..3edfa7d3753c29e44b964c0ff424d2ea8d5fdee6 -static void add_rpc_devices(std::string servers) { - auto rpc_servers = string_split(servers, ','); - if (rpc_servers.empty()) { - throw std::invalid_argument("no RPC servers specified"); - } - ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); - if (!rpc_reg) { - throw std::invalid_argument("failed to find RPC backend"); - } - typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); - ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); - if (!ggml_backend_rpc_add_device_fn) { - throw std::invalid_argument("failed to find RPC device add function"); - } - for (const auto & server : rpc_servers) { - ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); - if (dev) { - ggml_backend_device_register(dev); + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } } else { - throw std::invalid_argument("failed to register RPC device"); + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str()); + } else { + params.oaicompat_chat_format = defaults.oaicompat_chat_format; + } + } + + { + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto & t : *preserved_tokens) { + auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + SRV_DBG("Preserved token: %d\n", ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + } else { + // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. + SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); + } + } + } + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + server_grammar_trigger ct(t); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = ct.value.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = word; + trigger.token = token; + params.sampling.grammar_triggers.push_back(std::move(trigger)); + } else { + SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + } else { + params.sampling.grammar_triggers.push_back(std::move(ct.value)); + } + } + } + if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } + } + + { + params.sampling.logit_bias.clear(); + params.ignore_eos = json_value(data, "ignore_eos", false); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + } else if (samplers->is_string()){ + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; + } + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + + json to_json() const { + json base = { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; + } +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return false; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; } } -// convert a vector of completion_token_output to json -static json probs_vector_to_json(const llama_context *ctx, const std::vector &probs) -{ - json out = json::array(); - for (const auto &prob : probs) - { +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const { json probs_for_token = json::array(); - for (const auto &p : prob.probs) - { - std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); - probs_for_token.push_back(json - { - {"tok_str", tok_str}, - {"prob", p.prob}, + for (const auto & p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, }); } - std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); - out.push_back(json{ - {"content", tok_str}, - {"probs", probs_for_token}, - }); + return probs_for_token; } - return out; + + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ? "top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, + }); + } + return out; + } + + static float logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); + } + + static std::vector str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; + } +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + slot_params generation_params; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens {} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + std::string finish_reason = "length"; + common_chat_msg msg; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + SRV_DBG("Parsing chat message: %s\n", content.c_str()); + msg = common_chat_parse(content, oaicompat_chat_format); + finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; + } else { + msg.content = content; + } + + json message { + {"role", "assistant"}, + }; + if (!msg.reasoning_content.empty()) { + message["reasoning_content"] = msg.reasoning_content; + } + if (msg.content.empty() && !msg.tool_calls.empty()) { + message["content"] = json(); + } else { + message["content"] = msg.content; + } + if (!msg.tool_calls.empty()) { + auto tool_calls = json::array(); + for (const auto & tc : msg.tool_calls) { + tool_calls.push_back({ + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo). + // We only generate a random id for the ones that don't generate one by themselves + // (they also won't get to see it as their template likely doesn't use it, so it's all for the client) + {"id", tc.id.empty() ? gen_tool_call_id() : tc.id}, + }); + } + message["tool_calls"] = tool_calls; + } + + json choice { + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", message}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json { + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + + json choice = json { + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()} + }; + + json ret = json { + {"choices", json::array({choice})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + // extra fields for debugging purposes + if (verbose) { + ret["__verbose"] = to_json_non_oaicompat(); + } + + return ret; + } +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + completion_token_output prob_output; + result_timings timings; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + // non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + bool first = n_decoded == 0; + std::time_t t = std::time(0); + json choices; + + if (first) { + if (content.empty()) { + choices = json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}}}); + } else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{{"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + json second_ret = json{ + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json { + {"content", content}}} + }})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + return std::vector({initial_ret, second_ret}); + } + } else { + choices = json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json { + {"content", content}, + }}, + }}); + } + + GGML_ASSERT(choices.size() >= 1); + + if (prob_output.probs.size() > 0) { + choices[0]["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + json ret = json { + {"choices", choices}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"} + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return std::vector({ret}); + } +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? to_json_oaicompat() + : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; + } + + json to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +// this function maybe used outside of server_task_result_error +static json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; } -struct llama_client_slot -{ +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override { + return format_error_response(err_msg, err_type); + } +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + int32_t kv_cache_tokens_count; + int32_t kv_cache_used_cells; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); + + virtual json to_json() override { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "kv_cache_tokens_count", kv_cache_tokens_count }, + { "kv_cache_used_cells", kv_cache_used_cells }, + + { "slots", slots_data }, + }; + } +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } else { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; + } + } +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; + } +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override { + return json {{ "success", true }}; + } +}; + +struct server_slot { int id; - int task_id = -1; + int id_task = -1; + + // only used for completion/embedding/infill/rerank + server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; + + llama_batch batch_spec = {}; + + llama_context * ctx = nullptr; + llama_context * ctx_dft = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + common_speculative * spec = nullptr; + + std::vector lora; + + // the index relative to completion multi-task request + size_t index = 0; struct slot_params params; - slot_state state = IDLE; - slot_command command = NONE; + slot_state state = SLOT_STATE_IDLE; // used to determine the slot that has been used the longest int64_t t_last_used = -1; @@ -203,200 +1276,270 @@ struct llama_client_slot int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; - int32_t n_predict = -1; + int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - int32_t num_prompt_tokens = 0; - int32_t num_prompt_tokens_processed = 0; + // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated + int32_t n_prompt_tokens = 0; + int32_t n_prompt_tokens_processed = 0; - json prompt; - json data; + // input prompt tokens + server_tokens prompt_tokens; + + size_t last_nl_pos = 0; + + std::string generated_text; + llama_tokens generated_tokens; + + server_tokens cache_tokens; - std::string generated_text; - llama_token sampled; - std::vector cache_tokens; std::vector generated_token_probs; - bool infill = false; - bool embedding = false; bool has_next_token = true; - bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; - - bool oaicompat = false; - std::string oaicompat_model; + bool has_new_line = false; + bool truncated = false; + stop_type stop; std::string stopping_word; // sampling - struct common_params_sampling sparams; - common_sampler *ctx_sampling = nullptr; + json json_schema; - int32_t ga_i = 0; // group-attention state - int32_t ga_n = 1; // group-attention factor - int32_t ga_w = 512; // group-attention width + struct common_sampler * smpl = nullptr; - int32_t n_past_se = 0; // self-extend + llama_token sampled; - // multimodal - mtmd_context * mctx = nullptr; + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; // stats - size_t sent_count = 0; - size_t sent_token_probs_index = 0; + size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; - int64_t t_start_genereration; + int64_t t_start_generation; double t_prompt_processing; // ms - double t_token_generation; // ms + double t_token_generation; // ms - // multitasks - int multitask_id = -1; + std::function callback_on_release; + + // Speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted void reset() { - num_prompt_tokens = 0; - generated_text = ""; - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - n_past = 0; - sent_count = 0; - sent_token_probs_index = 0; - infill = false; - ga_i = 0; - n_past_se = 0; + SLT_DBG(*this, "%s", "\n"); + n_prompt_tokens = 0; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + task_type = SERVER_TASK_TYPE_COMPLETION; + + generated_tokens.clear(); generated_token_probs.clear(); + + // clear speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; } - bool has_budget(common_params &global_params) { - if (params.n_predict == -1 && global_params.n_predict == -1) - { + bool is_non_causal() const { + return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; + } + + bool can_batch_with(server_slot & other_slot) const { + return is_non_causal() == other_slot.is_non_causal() + && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params & global_params) { + if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } n_remaining = -1; - if (params.n_predict != -1) - { + if (params.n_predict != -1) { n_remaining = params.n_predict - n_decoded; - } - else if (global_params.n_predict != -1) - { + } else if (global_params.n_predict != -1) { n_remaining = global_params.n_predict - n_decoded; } return n_remaining > 0; // no budget } - bool available() const { - return state == IDLE && command == NONE; - } - bool is_processing() const { - return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING; + return state != SLOT_STATE_IDLE; } - void add_token_string(const completion_token_output &token) { - if (command == RELEASE) - { + bool can_speculate() const { + return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + } + + void add_token(const completion_token_output & token) { + if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); return; } - cache_tokens.push_back(token.tok); generated_token_probs.push_back(token); } void release() { - if (state == PROCESSING) - { - t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; - command = RELEASE; + if (is_processing()) { + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); + + t_last_used = ggml_time_us(); + t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; + state = SLOT_STATE_IDLE; + callback_on_release(id); } } - json get_formated_timings() { - return json - { - {"prompt_n", num_prompt_tokens_processed}, - {"prompt_ms", t_prompt_processing}, - {"prompt_per_token_ms", t_prompt_processing / num_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / t_prompt_processing * num_prompt_tokens_processed}, + result_timings get_timings() const { + result_timings timings; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - {"predicted_n", n_decoded}, - {"predicted_ms", t_token_generation}, - {"predicted_per_token_ms", t_token_generation / n_decoded}, - {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, - }; + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + + return timings; + } + + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + size_t stop_pos = std::string::npos; + + for (const std::string & word : params.antiprompt) { + size_t pos; + + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + + pos = text.find(word, from_pos); + } else { + // otherwise, partial stop + pos = find_partial_stop_string(word, text); + } + + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stop = STOP_TYPE_WORD; + stopping_word = word; + has_next_token = false; + } + stop_pos = pos; + } + } + + return stop_pos; } void print_timings() const { - char buffer[512]; - double t_token = t_prompt_processing / num_prompt_tokens_processed; - double n_tokens_second = 1e3 / t_prompt_processing * num_prompt_tokens_processed; - sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", - t_prompt_processing, num_prompt_tokens_processed, - t_token, n_tokens_second); - LOG_INFO(buffer, { - {"slot_id", id}, - {"task_id", task_id}, - {"t_prompt_processing", t_prompt_processing}, - {"num_prompt_tokens_processed", num_prompt_tokens_processed}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - t_token = t_token_generation / n_decoded; - n_tokens_second = 1e3 / t_token_generation * n_decoded; - sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", - t_token_generation, n_decoded, - t_token, n_tokens_second); - LOG_INFO(buffer, { - {"slot_id", id}, - {"task_id", task_id}, - {"t_token_generation", t_token_generation}, - {"n_decoded", n_decoded}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; - sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation); - LOG_INFO(buffer, { - {"slot_id", id}, - {"task_id", task_id}, - {"t_prompt_processing", t_prompt_processing}, - {"t_token_generation", t_token_generation}, - {"t_total", t_prompt_processing + t_token_generation}, - }); + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_token_generation, n_decoded, t_gen, n_gen_second, + t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + + if (n_draft_total > 0) { + const float draft_ratio = (float) n_draft_accepted / n_draft_total; + SLT_INF(*this, + "\n" + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total + ); + } + } + + json to_json() const { + return json { + {"id", id}, + {"id_task", id_task}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + {"non_causal", is_non_causal()}, + {"params", params.to_json()}, + {"prompt", prompt_tokens.detokenize(ctx, true)}, + {"next_token", + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + {"stopping_word", stopping_word}, + } + }, + }; } }; -struct llama_metrics { +struct server_metrics { + int64_t t_start = 0; + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; uint64_t n_prompt_tokens_processed = 0; uint64_t t_prompt_processing = 0; - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; - void on_prompt_eval(const llama_client_slot &slot) { - n_prompt_tokens_processed_total += slot.num_prompt_tokens_processed; - - n_prompt_tokens_processed += slot.num_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; + void init() { + t_start = ggml_time_us(); } - void on_prediction(const llama_client_slot &slot) { - n_tokens_predicted_total += slot.n_decoded; + void on_prompt_eval(const server_slot & slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; + } - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; + void on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; + } + + void on_decoded(const std::vector & slots) { + n_decode_total++; + for (const auto & slot : slots) { + if (slot.is_processing()) { + n_busy_slots_total++; + } + } } void reset_bucket() { @@ -407,1441 +1550,1766 @@ struct llama_metrics { } }; -struct llava_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; +struct server_queue { + int id = 0; + bool running; + + // queues + std::deque queue_tasks; + std::deque queue_tasks_deferred; + + std::mutex mutex_tasks; + std::condition_variable condition_tasks; + + // callback functions + std::function callback_new_task; + std::function callback_update_slots; + + // Add a new task to the end of the queue + int post(server_task && task, bool front = false) { + std::unique_lock lock(mutex_tasks); + GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); } + const int task_id = task.id; + QUE_DBG("new task, id = %d, front = %d\n", task_id, front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + condition_tasks.notify_one(); + return task_id; + } + + // multi-task version of post() + int post(std::vector && tasks, bool front = false) { + std::unique_lock lock(mutex_tasks); + for (auto & task : tasks) { + if (task.id == -1) { + task.id = id++; + } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + } + condition_tasks.notify_one(); + return 0; + } + + // Add a new task, but defer until one slot is available + void defer(server_task && task) { + std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); + queue_tasks_deferred.push_back(std::move(task)); + condition_tasks.notify_one(); + } + + // Get the next id for creating a new task + int get_new_id() { + std::unique_lock lock(mutex_tasks); + int new_id = id++; + return new_id; + } + + // Register function to process a new task + void on_new_task(std::function callback) { + callback_new_task = std::move(callback); + } + + // Register the function to be called when all slots data is ready to be processed + void on_update_slots(std::function callback) { + callback_update_slots = std::move(callback); + } + + // Call when the state of one slot is changed, it will move one task from deferred to main queue + void pop_deferred_task() { + std::unique_lock lock(mutex_tasks); + if (!queue_tasks_deferred.empty()) { + queue_tasks.emplace_back(std::move(queue_tasks_deferred.front())); + queue_tasks_deferred.pop_front(); + } + condition_tasks.notify_one(); + } + + // end the start_loop routine + void terminate() { + std::unique_lock lock(mutex_tasks); + running = false; + condition_tasks.notify_all(); + } + + /** + * Main loop consists of these steps: + * - Wait until a new task arrives + * - Process the task (i.e. maybe copy data into slot) + * - Check if multitask is finished + * - Update all slots + */ + void start_loop() { + running = true; + + while (true) { + QUE_DBG("%s", "processing new tasks\n"); + + while (true) { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + server_task task = std::move(queue_tasks.front()); + queue_tasks.pop_front(); + lock.unlock(); + + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); + } + + // all tasks in the current loop is processed, slots data is now ready + QUE_DBG("%s", "update slots\n"); + + callback_update_slots(); + + QUE_DBG("%s", "waiting for new tasks\n"); + { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&]{ + return (!queue_tasks.empty() || !running); + }); + } + } + } + } + +private: + void cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task & task) { + return task.id_target == id_target; + }; + queue_tasks.erase( + std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), + queue_tasks.end()); + queue_tasks_deferred.erase( + std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); } }; -struct llama_server_context -{ - llama_model *model = nullptr; - llama_context *ctx = nullptr; - const llama_vocab * vocab = nullptr; +struct server_response { + bool running = true; + + // for keeping track of all tasks waiting for the result + std::unordered_set waiting_task_ids; + + // the main result queue (using ptr for polymorphism) + std::vector queue_results; + + std::mutex mutex_results; + std::condition_variable condition_results; + + // add the id_task to the list of tasks waiting for response + void add_waiting_task_id(int id_task) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(id_task); + } + + void add_waiting_tasks(const std::vector & tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & task : tasks) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); + waiting_task_ids.insert(task.id); + } + } + + // when the request is finished, we can remove task associated with it + void remove_waiting_task_id(int id_task) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); + } + + void remove_waiting_task_ids(const std::unordered_set & id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & id_task : id_tasks) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } + } + + // This function blocks the thread until there is a response for one of the id_tasks + server_task_result_ptr recv(const std::unordered_set & id_tasks) { + while (true) { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&]{ + if (!running) { + SRV_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + return !queue_results.empty(); + }); + + for (size_t i = 0; i < queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here + } + + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (!running) { + SRV_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + if (cr_res == std::cv_status::timeout) { + return nullptr; + } + } + + // should never reach here + } + + // single-task version of recv() + server_task_result_ptr recv(int id_task) { + std::unordered_set id_tasks = {id_task}; + return recv(id_tasks); + } + + // Send a new result to a waiting id_task + void send(server_task_result_ptr && result) { + SRV_DBG("sending result for task id = %d\n", result->id); + + std::unique_lock lock(mutex_results); + for (const auto & id_task : waiting_task_ids) { + if (result->id == id_task) { + SRV_DBG("task id = %d pushed to result queue\n", result->id); + + queue_results.emplace_back(std::move(result)); + condition_results.notify_all(); + return; + } + } + } + + // terminate the waiting loop + void terminate() { + running = false; + condition_results.notify_all(); + } +}; + +struct server_context { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model * model = nullptr; + llama_context * ctx = nullptr; // multimodal mtmd_context * mctx = nullptr; - clip_ctx *clp_ctx = nullptr; + const llama_vocab * vocab = nullptr; - common_params params; + llama_model * model_dft = nullptr; - llama_batch batch; + llama_context_params cparams_dft; - bool multimodal = false; - bool clean_kv_cache = true; - bool all_slots_are_idle = false; - bool add_bos_token = true; - bool has_eos_token = true; - bool has_gpu = false; + llama_batch batch {}; - bool grammar_lazy = false; - std::vector grammar_triggers; + bool clean_kv_cache = true; + bool add_bos_token = true; + bool has_eos_token = false; - int32_t n_ctx; // total context for all clients / slots - - // system prompt - bool system_need_update = false; - - std::string system_prompt; - std::vector system_tokens; - - std::string name_user; // this should be the antiprompt - std::string name_assistant; + int32_t n_ctx; // total context for all clients / slots // slots / clients - std::vector slots; + std::vector slots; json default_generation_settings_for_props; - llama_server_queue queue_tasks; - llama_server_response queue_results; + server_queue queue_tasks; + server_response queue_results; - llama_metrics metrics; + server_metrics metrics; - ~llama_server_context() - { - if (mctx) { - mtmd_free(mctx); - mctx = nullptr; - } - if (ctx) - { - llama_free(ctx); - ctx = nullptr; - } - if (model) - { - llama_free_model(model); - model = nullptr; + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + + common_chat_templates_ptr chat_templates; + + ~server_context() { + mtmd_free(mctx); + + // Clear any sampling context + for (server_slot & slot : slots) { + common_sampler_free(slot.smpl); + slot.smpl = nullptr; + + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; + + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); } + + llama_batch_free(batch); } - bool load_model(const common_params ¶ms_) - { - params = params_; - if (!params.mmproj.path.empty()) { - multimodal = true; - LOG_INFO("Multi Modal Mode Enabled", {}); - mtmd_context_params mparams = mtmd_context_params_default(); - mparams.use_gpu = has_gpu; - mparams.print_timings = false; - mparams.n_threads = params.cpuparams.n_threads; - mparams.verbosity = GGML_LOG_LEVEL_INFO; - mctx = mtmd_init_from_file(params.mmproj.path.c_str(), model, mparams); - if (mctx == nullptr) { - LOG_ERR("failed to load multimodal model, '%s'\n", params.mmproj.path.c_str()); - return false; - } + bool load_model(const common_params & params) { + SRV_INF("loading model '%s'\n", params.model.path.c_str()); - if (params.n_ctx < 2048) { // request larger context for the image embedding - params.n_ctx = 2048; - } - } + params_base = params; - common_init_result common_init = common_init_from_params(params); - model = common_init.model.release(); - ctx = common_init.context.release(); - if (model == nullptr) - { - LOG_ERR("unable to load model: %s", params.model.path.c_str()); + llama_init = common_init_from_params(params_base); + + model = llama_init.model.get(); + ctx = llama_init.context.get(); + + if (model == nullptr) { + SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - if (multimodal) { - const int n_embd_clip = clip_n_mmproj_embd(clp_ctx); - const int n_embd_llm = llama_model_n_embd(model); - if (n_embd_clip != n_embd_llm) { - LOG("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm); - llama_free(ctx); - llama_free_model(model); - return false; - } - } - vocab = llama_model_get_vocab(model); + n_ctx = llama_n_ctx(ctx); add_bos_token = llama_vocab_get_add_bos(vocab); has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); + + auto params_dft = params_base; + + params_dft.devices = params_base.speculative.devices; + params_dft.model = params_base.speculative.model; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + + // force F16 KV cache for the draft model for extra performance + params_dft.cache_type_k = GGML_TYPE_F16; + params_dft.cache_type_v = GGML_TYPE_F16; + + llama_init_dft = common_init_from_params(params_dft); + + model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); + return false; + } + + if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); + + return false; + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); + } + + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja); + } catch (const std::exception & e) { + SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + chat_templates = common_chat_templates_init(model, "chatml"); + } + + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params_base.cpuparams.n_threads; + mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + return true; } - llama_client_slot* get_active_slot() { - for (llama_client_slot& slot : slots) { - // Check if the slot is currently processing - if (slot.is_processing()) { - return &slot; // Return the active slot - } - } - return nullptr; // No active slot found - } + void init() { + const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; - void initialize() { - // create slots - all_slots_are_idle = true; + SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - const int32_t n_ctx_slot = n_ctx / params.n_parallel; - - LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}}); - for (int i = 0; i < params.n_parallel; i++) - { - llama_client_slot slot; + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot slot; slot.id = i; + slot.ctx = ctx; slot.n_ctx = n_ctx_slot; - slot.n_predict = params.n_predict; + slot.n_predict = params_base.n_predict; slot.mctx = mctx; - //slot.cache_tokens.has_mtmd = mctx != nullptr; + slot.cache_tokens.has_mtmd = mctx != nullptr; - LOG_INFO("new slot", { - {"slot_id", slot.id}, - {"n_ctx_slot", slot.n_ctx} - }); + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - const int ga_n = params.grp_attn_n; - const int ga_w = params.grp_attn_w; + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } - if (ga_n != 1) { - GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT - GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT - //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT - //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT - - LOG_INFO("slot self-extend", { - {"slot_id", slot.id}, - {"ga_n", ga_n}, - {"ga_w", ga_w} - }); + slot.spec = common_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } } - slot.ga_i = 0; - slot.ga_n = ga_n; - slot.ga_w = ga_w; + SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); + + slot.params.sampling = params_base.sampling; + + slot.callback_on_release = [this](int) { + queue_tasks.pop_deferred_task(); + }; slot.reset(); - slots.push_back(slot); + slots.push_back(std::move(slot)); } - default_generation_settings_for_props = get_formated_generation(slots.front()); - default_generation_settings_for_props["seed"] = -1; + default_generation_settings_for_props = slots[0].to_json(); - batch = llama_batch_init(n_ctx, 0, params.n_parallel); - } - - std::vector tokenize(json &data, const json & json_prompt, bool add_bos) const - { - mtmd::bitmaps bitmaps; - std::vector inputs; - - if (mctx != nullptr) + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { - const auto &images_data = data.find("image_data"); - if (images_data != data.end() && images_data->is_array()) - { - for (const auto &img : *images_data) - { - const std::vector image_buffer = base64_decode(img["data"].get()); - - mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(image_buffer.data(), image_buffer.size())); - if (!bmp.ptr) { - throw std::runtime_error("Failed to load image"); - } - // calculate bitmap hash (for KV caching) - std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); - bmp.set_id(hash.c_str()); - bitmaps.entries.push_back(std::move(bmp)); - } - } - - // multimodal - std::string prompt_str = json_prompt.template get(); - mtmd_input_text inp_txt = { - prompt_str.c_str(), - /* add_special */ true, - /* parse_special */ true, - }; - mtmd::input_chunks chunks(mtmd_input_chunks_init()); - auto bitmaps_c_ptr = bitmaps.c_ptr(); - int32_t tokenized = mtmd_tokenize(mctx, - chunks.ptr.get(), - &inp_txt, - bitmaps_c_ptr.data(), - bitmaps_c_ptr.size()); - if (tokenized != 0) { - throw std::runtime_error("Failed to tokenize prompt"); - } - - server_tokens tmp(chunks, true); - inputs.push_back(std::move(tmp)); - } else { - // non-multimodal version - auto tokenized_prompts = tokenize_input_prompts(vocab, json_prompt, true, true); - for (auto & p : tokenized_prompts) { - auto tmp = server_tokens(p, mctx != nullptr); - inputs.push_back(std::move(tmp)); - } + const int32_t n_batch = llama_n_batch(ctx); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } - return inputs; + metrics.init(); } - llama_client_slot* get_slot(int id) { - int64_t t_last = ggml_time_us(); - llama_client_slot *last_used = nullptr; - - for (llama_client_slot & slot : slots) - { - if (slot.id == id && slot.available()) - { + server_slot * get_slot_by_id(int id) { + for (server_slot & slot : slots) { + if (slot.id == id) { return &slot; } - - if (slot.available() && slot.t_last_used < t_last) - { - last_used = &slot; - t_last = slot.t_last_used; - } } - return last_used; + return nullptr; } - bool launch_slot_with_data(llama_client_slot* &slot, json data) { - slot_params default_params; - common_params_sampling default_sparams; - - slot->params.stream = json_value(data, "stream", false); - slot->params.cache_prompt = json_value(data, "cache_prompt", false); - slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); - slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot->sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p); - slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot->sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); - slot->sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); - slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); - slot->sparams.seed = json_value(data, "seed", default_sparams.seed); - slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); - slot->sparams.grammar_triggers = grammar_triggers; - slot->sparams.grammar_lazy = grammar_lazy; + server_slot * get_available_slot(const server_task & task) { + server_slot * ret = nullptr; - slot->data = data; + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + int lcs_len = 0; + float similarity = 0; - if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) { + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // skip the slot if it does not contains cached tokens + if (slot.cache_tokens.empty()) { + continue; + } + + // length of the Longest Common Subsequence between the current slot's prompt and the input prompt + int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); + + // fraction of the common subsequence length compared to the current slot's prompt length + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + + // select the current slot if the criteria match + if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { + lcs_len = cur_lcs_len; + similarity = cur_similarity; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity); + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = ggml_time_us(); + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // select the current slot if the criteria match + if (slot.t_last_used < t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last); + } + } + + return ret; + } + + bool launch_slot_with_task(server_slot & slot, server_task && task) { + slot.reset(); + slot.id_task = task.id; + slot.index = task.index; + slot.task_type = task.type; + slot.params = std::move(task.params); + slot.prompt_tokens = std::move(task.prompt_tokens); + + if (!are_lora_equal(slot.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = slot.params.lora; + } + + if (!slot.prompt_tokens.validate(ctx)) { + send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); + return false; + } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + + if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? - LOG_WARNING("Max tokens to predict exceeds server configuration", { - {"params.n_predict", slot->params.n_predict}, - {"slot.n_predict", slot->n_predict}, - }); - slot->params.n_predict = slot->n_predict; + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); + slot.params.n_predict = slot.n_predict; } - // infill - if (data.count("input_prefix") != 0) - { - slot->params.input_prefix = data["input_prefix"]; - } - else - { - slot->params.input_prefix = ""; + if (slot.params.ignore_eos && has_eos_token) { + slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); } - - if (data.count("input_suffix") != 0) { - slot->params.input_suffix = data["input_suffix"]; - } - else - { - slot->params.input_suffix = ""; - } + if (slot.smpl != nullptr) { + common_sampler_free(slot.smpl); + } - if (data.count("prompt") != 0) - { - slot->prompt = data["prompt"]; - } - else - { - slot->prompt = ""; - } - - if (json_value(data, "ignore_eos", false) && has_eos_token) { - slot->sparams.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); - } - - slot->sparams.logit_bias.clear(); - - const auto &logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) - { - const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto &el : *logit_bias) - { - if (el.is_array() && el.size() == 2) - { - float bias; - if (el[1].is_number()) - { - bias = el[1].get(); - } - else if (el[1].is_boolean() && !el[1].get()) - { - bias = -INFINITY; - } - else - { - continue; - } - - if (el[0].is_number_integer()) - { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) - { - slot->sparams.logit_bias.push_back({tok, bias}); - } - } - else if (el[0].is_string()) - { - auto toks = common_tokenize(vocab, el[0].get(), false); - for (auto tok : toks) - { - slot->sparams.logit_bias.push_back({tok, bias}); - } - } - } + slot.smpl = common_sampler_init(model, slot.params.sampling); + if (slot.smpl == nullptr) { + // for now, the only error that may happen here is invalid grammar + send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + return false; } } - - slot->params.antiprompt.clear(); - const auto &stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) - { - for (const auto &word : *stop) - { - if (!word.empty()) - { - slot->params.antiprompt.push_back(word); - } - } + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); } - - const auto & samplers = data.find("samplers"); - if (samplers != data.end() && samplers->is_array()) { - std::vector sampler_names; - for (const auto & name : *samplers) { - if (name.is_string()) { - sampler_names.emplace_back(name); - } - } - slot->sparams.samplers = common_sampler_types_from_names(sampler_names, false); - } - else - { - slot->sparams.samplers = default_sparams.samplers; - } - + slot.state = SLOT_STATE_STARTED; - if (slot->ctx_sampling != nullptr) - { - common_sampler_free(slot->ctx_sampling); - } - slot->ctx_sampling = common_sampler_init(model, slot->sparams); - //llama_set_rng_seed(ctx, slot->params.seed); - slot->command = LOAD_PROMPT; - - all_slots_are_idle = false; - - LOG_INFO("slot is processing task", { - {"slot_id", slot->id}, - {"task_id", slot->task_id}, - }); - - // LOG("sampling: \n%s\n", llama_sampling_print(slot->sparams).c_str()); + SLT_INF(slot, "%s", "processing task\n"); return true; } void kv_cache_clear() { + SRV_DBG("%s", "clearing KV cache\n"); + // clear the entire KV cache - llama_kv_cache_clear(ctx); + llama_kv_self_clear(ctx); clean_kv_cache = false; } - void update_system_prompt() { - kv_cache_clear(); - system_tokens.clear(); - - if (!system_prompt.empty()) { - system_tokens = common_tokenize(ctx, system_prompt, add_bos_token); - - common_batch_clear(batch); - - for (int i = 0; i < (int)system_tokens.size(); ++i) - { - common_batch_add(batch, system_tokens[i], i, { 0 }, false); - } - - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) - { - const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i)); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - if (llama_decode(ctx, batch_view) != 0) - { - LOG("%s: llama_decode() failed\n", __func__); - return; - } - } - - // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i < params.n_parallel; ++i) - { - llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size()); - } - } - - LOG("system prompt updated\n"); - system_need_update = false; - } - - void notify_system_prompt_changed() { - // release all slots - for (llama_client_slot &slot : slots) - { - slot.release(); - } - - system_need_update = true; - } - - void process_system_prompt_data(const json &sys_props) { - system_prompt = sys_props.value("prompt", ""); - name_user = sys_props.value("anti_prompt", ""); - name_assistant = sys_props.value("assistant_name", ""); - - - notify_system_prompt_changed(); - } - - static size_t find_stopping_strings(const std::string &text, const size_t last_token_size, - const stop_type type, llama_client_slot &slot) - { - size_t stop_pos = std::string::npos; - - for (const std::string &word : slot.params.antiprompt) - { - size_t pos; - if (type == STOP_FULL) - { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; - pos = text.find(word, from_pos); - } - else - { - pos = find_partial_stop_string(word, text); - } - if (pos != std::string::npos && - (stop_pos == std::string::npos || pos < stop_pos)) - { - if (type == STOP_FULL) - { - slot.stopped_word = true; - slot.stopping_word = word; - slot.has_next_token = false; - } - stop_pos = pos; - } - } - - return stop_pos; - } - - bool process_token(completion_token_output &result, llama_client_slot &slot) { + bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = common_token_to_piece(ctx, result.tok); + const std::string token_str = result.text_to_send; slot.sampled = result.tok; - // search stop word and delete it slot.generated_text += token_str; + if (slot.params.return_tokens) { + slot.generated_tokens.push_back(result.tok); + } slot.has_next_token = true; -/* - if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) - { - // we can change penalty_prompt_tokens because it is always created from scratch each request - slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); - } - */ - // check if there is incomplete UTF-8 character at the end - bool incomplete = false; - for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) - { - unsigned char c = slot.generated_text[slot.generated_text.size() - i]; - if ((c & 0xC0) == 0x80) - { - // continuation byte: 10xxxxxx - continue; - } - if ((c & 0xE0) == 0xC0) - { - // 2-byte character: 110xxxxx ... - incomplete = i < 2; - } - else if ((c & 0xF0) == 0xE0) - { - // 3-byte character: 1110xxxx ... - incomplete = i < 3; - } - else if ((c & 0xF8) == 0xF0) - { - // 4-byte character: 11110xxx ... - incomplete = i < 4; - } - // else 1-byte character or invalid byte - break; - } + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + + // search stop word and delete it + if (!incomplete) { + size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); - if (!incomplete) - { - size_t pos = std::min(slot.sent_count, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); - bool is_stop_full = false; - size_t stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_FULL, slot); - if (stop_pos != std::string::npos) - { - is_stop_full = true; + bool send_text = true; + + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { slot.generated_text.erase( slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); - pos = std::min(slot.sent_count, slot.generated_text.size()); - } - else - { - is_stop_full = false; - stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL, slot); + pos = std::min(slot.n_sent_text, slot.generated_text.size()); + } else if (slot.has_next_token) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; } // check if there is any token to predict - if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) - { + if (send_text) { // no send the stop word in the response result.text_to_send = slot.generated_text.substr(pos, std::string::npos); - slot.sent_count += result.text_to_send.size(); + slot.n_sent_text += result.text_to_send.size(); // add the token to slot queue and cache + } else { + result.text_to_send = ""; } - slot.add_token_string(result); - if (slot.params.stream) - { + + slot.add_token(result); + if (slot.params.stream) { send_partial_response(slot, result); } } - if (incomplete) - { + if (incomplete) { slot.has_next_token = true; } // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) - { - slot.stopped_limit = true; + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; + + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); } + if (slot.has_new_line) { + // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent + if (slot.params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } + } + + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + + // if we have seen a new line, we stop after a certain time limit, but only upon another new line + if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + } + } + + // if context shift is disabled, we stop when it reaches the context limit if (slot.n_past >= slot.n_ctx) { slot.truncated = true; - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - LOG_VERBOSE("stopped due to running out of context capacity", {}); + SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); } - if (result.tok == llama_vocab_eos(vocab) || llama_vocab_is_eog(vocab, result.tok)) - { - slot.stopped_eos = true; + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; slot.has_next_token = false; - LOG_VERBOSE("eos token found", {}); + + SLT_DBG(slot, "%s", "stopped by EOS\n"); } - LOG_VERBOSE("next token", { - {"token", result.tok}, - {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"num_tokens_predicted", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }); + const auto n_ctx_train = llama_model_n_ctx_train(model); + + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; // stop prediction + + SLT_WRN(slot, + "n_predict (%d) is set for infinite generation. " + "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", + slot.params.n_predict, n_ctx_train); + } + + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); return slot.has_next_token; // continue } - void send_error(task_server& task, const std::string &error) - { - LOG("task %i - error: %s\n", task.id, error.c_str()); - task_result res; - res.id = task.id; - res.multitask_id = task.multitask_id; - res.stop = false; - res.error = true; - res.result_json = { { "content", error } }; - queue_results.send(res); - } + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { + size_t n_probs = slot.params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + if (post_sampling) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; - json get_formated_generation(llama_client_slot &slot) - { - std::vector samplers; - samplers.reserve(slot.sparams.samplers.size()); - for (const auto & sampler : slot.sparams.samplers) - { - samplers.emplace_back(common_sampler_type_to_str(sampler)); - } - - return json { - {"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, - {"model", params.model_alias}, - {"seed", slot.params.seed}, - {"temperature", slot.sparams.temp}, - {"dynatemp_range", slot.sparams.dynatemp_range}, - {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, - {"top_k", slot.sparams.top_k}, - {"top_p", slot.sparams.top_p}, - {"min_p", slot.sparams.min_p}, - {"typical_p", slot.sparams.typ_p}, - {"repeat_last_n", slot.sparams.penalty_last_n}, - {"repeat_penalty", slot.sparams.penalty_repeat}, - {"presence_penalty", slot.sparams.penalty_present}, - {"frequency_penalty", slot.sparams.penalty_freq}, - {"mirostat", slot.sparams.mirostat}, - {"mirostat_tau", slot.sparams.mirostat_tau}, - {"mirostat_eta", slot.sparams.mirostat_eta}, - {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, - {"n_keep", params.n_keep}, - {"ignore_eos", slot.sparams.ignore_eos}, - {"stream", slot.params.stream}, - // {"logit_bias", slot.sparams.logit_bias}, - {"n_probs", slot.sparams.n_probs}, - {"min_keep", slot.sparams.min_keep}, - {"grammar", slot.sparams.grammar}, - {"samplers", samplers} - }; - } - - void send_partial_response(llama_client_slot &slot, completion_token_output tkn) - { - task_result res; - res.id = slot.task_id; - res.multitask_id = slot.multitask_id; - res.error = false; - res.stop = false; - - res.result_json = json - { - {"content", tkn.text_to_send}, - {"stop", false}, - {"slot_id", slot.id}, - {"multimodal", multimodal} - }; - - if (slot.sparams.n_probs > 0) - { - std::vector probs_output = {}; - const std::vector to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); - size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size()); - size_t probs_stop_pos = std::min(slot.sent_token_probs_index + to_send_toks.size(), slot.generated_token_probs.size()); - if (probs_pos < probs_stop_pos) - { - probs_output = std::vector(slot.generated_token_probs.begin() + probs_pos, slot.generated_token_probs.begin() + probs_stop_pos); - } - slot.sent_token_probs_index = probs_stop_pos; - res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); - } - - if (slot.oaicompat) - { - res.result_json["oaicompat_token_ctr"] = slot.n_decoded; - res.result_json["model"] = slot.oaicompat_model; - } - - queue_results.send(res); - } - - void send_final_response(llama_client_slot &slot) - { - task_result res; - res.id = slot.task_id; - res.multitask_id = slot.multitask_id; - res.error = false; - res.stop = true; - - res.result_json = json - { - {"content", !slot.params.stream ? slot.generated_text : ""}, - {"slot_id", slot.id}, - {"stop", true}, - {"model", params.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.num_prompt_tokens}, - {"generation_settings", get_formated_generation(slot)}, - {"prompt", slot.prompt}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()} - }; - - if (slot.sparams.n_probs > 0) - { - std::vector probs = {}; - if (!slot.params.stream && slot.stopped_word) - { - const std::vector stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); - probs = std::vector(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size()); - } - else - { - probs = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); - } - res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs); - } - - if (slot.oaicompat) - { - res.result_json["oaicompat_token_ctr"] = slot.n_decoded; - res.result_json["model"] = slot.oaicompat_model; - } - - queue_results.send(res); - } - - void send_embedding(llama_client_slot &slot, const llama_batch & batch) - { - task_result res; - res.id = slot.task_id; - res.multitask_id = slot.multitask_id; - res.error = false; - res.stop = true; - - const int n_embd = llama_model_n_embd(model); - if (!params.embedding) - { - LOG_WARNING("embedding disabled", { - {"params.embedding", params.embedding}, - }); - res.result_json = json - { - {"embedding", std::vector(n_embd, 0.0f)}, - }; - } - else - { - const float *data = llama_get_embeddings(ctx); - std::vector embd_res(n_embd, 0.0f); - std::vector> embedding; - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - LOG("failed to get embeddings"); - - continue; - } - - // normalize only when there is pooling - // TODO: configurable - if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, 2); - embedding.push_back(embd_res); - } else { - embedding.push_back({ embd, embd + n_embd }); - } - } - - // OAI compat - res.result_json = json - { - {"embedding", embedding[0] }, - }; - } - queue_results.send(res); - } - - void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id) - { - task_server task; - task.id = task_id; - task.target_id = 0; - task.data = std::move(data); - task.infill_mode = infill; - task.embedding_mode = embedding; - task.type = TASK_TYPE_COMPLETION; - task.multitask_id = multitask_id; - - // when a completion task's prompt array is not a singleton, we split it into multiple requests - // otherwise, it's a single-prompt task, we actually queue it - // if there's numbers in the prompt array it will be treated as an array of tokens - if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) { - bool numbers = false; - for (const auto& e : task.data.at("prompt")) { - if (e.is_number()) { - numbers = true; + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; break; } } - // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers, - // it will completely stall the server. I don't know where the bug for this is. - // - // if there are numbers, it needs to be treated like a single prompt, - // queue_tasks handles a mix of strings and numbers just fine. - if (numbers) { - queue_tasks.post(task); - } else { - split_multiprompt_task(task_id, task); + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + common_token_to_piece(ctx, cur_p->data[i].id, special), + cur_p->data[i].p + }); } } else { - queue_tasks.post(task); - } - } + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); - void request_cancel(int task_id) - { - task_server task; - task.type = TASK_TYPE_CANCEL; - task.target_id = task_id; - queue_tasks.post(task); - } - - void split_multiprompt_task(int multitask_id, task_server& multiprompt_task) - { - int prompt_count = multiprompt_task.data.at("prompt").size(); - if (prompt_count <= 1) { - send_error(multiprompt_task, "error while handling multiple prompts"); - return; - } - - // generate all the ID for subtask - std::vector subtask_ids(prompt_count); - for (int i = 0; i < prompt_count; i++) - { - subtask_ids[i] = queue_tasks.get_new_id(); - } - - // queue up the multitask so we can track its subtask progression - queue_tasks.add_multitask(multitask_id, subtask_ids); - - // add subtasks - for (int i = 0; i < prompt_count; i++) - { - json subtask_data = multiprompt_task.data; - subtask_data["prompt"] = subtask_data["prompt"][i]; - - // subtasks inherit everything else (infill mode, embedding mode, etc.) - request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id); - } - } - - void process_single_task(task_server& task) - { - switch (task.type) - { - case TASK_TYPE_COMPLETION: { - llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); - if (slot == nullptr) - { - // if no slot is available, we defer this task for processing later - LOG_VERBOSE("no slot is available", {{"task_id", task.id}}); - queue_tasks.defer(task); + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; break; } - - if (task.data.contains("system_prompt")) - { - if (!all_slots_are_idle) { - send_error(task, "system prompt can only be updated when all slots are idle"); - break; - } - process_system_prompt_data(task.data["system_prompt"]); - - // reset cache_tokens for all slots - for (llama_client_slot &slot : slots) - { - slot.cache_tokens.clear(); - slot.n_past = 0; - slot.n_past_se = 0; - } - } - - slot->reset(); - - slot->infill = task.infill_mode; - slot->embedding = task.embedding_mode; - slot->task_id = task.id; - slot->multitask_id = task.multitask_id; - - if (!launch_slot_with_data(slot, task.data)) - { - // send error result - send_error(task, "internal_error"); - break; - } - } break; - case TASK_TYPE_CANCEL: { // release slot linked with the task id - for (auto & slot : slots) - { - if (slot.task_id == task.target_id) - { - slot.release(); - break; - } - } - } break; - case TASK_TYPE_NEXT_RESPONSE: { - // do nothing - } break; - } - } - - void on_finish_multitask(task_multi& multitask) - { - // all subtasks done == multitask is done - task_result result; - result.id = multitask.id; - result.stop = true; - result.error = false; - - // collect json results into one json result - std::vector result_jsons; - for (auto& subres : multitask.results) - { - result_jsons.push_back(subres.result_json); - result.error = result.error && subres.error; - } - result.result_json = json{ { "results", result_jsons } }; - queue_results.send(result); - } - - bool update_slots() { - if (system_need_update) - { - LOG_INFO("updating system prompt", {}); - update_system_prompt(); - } - - common_batch_clear(batch); - - if (all_slots_are_idle) - { - if (system_prompt.empty() && clean_kv_cache) - { - LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {}); - kv_cache_clear(); } - return true; - } - LOG_VERBOSE("posting NEXT_RESPONSE", {}); - task_server task; - task.type = TASK_TYPE_NEXT_RESPONSE; - task.target_id = -1; - queue_tasks.post(task); - - for (llama_client_slot &slot : slots) - { - if (slot.ga_n == 1) - { - if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) - { - // this check is redundant (for good) - // we should never get here, because generation should already stopped in process_token() - - // START LOCALAI changes - // Temporary disable context-shifting as it can lead to infinite loops (issue: https://github.com/ggerganov/llama.cpp/issues/3969) - // See: https://github.com/mudler/LocalAI/issues/1333 - // Context is exhausted, release the slot - slot.release(); - send_final_response(slot); - slot.has_next_token = false; - LOG_ERROR("context is exhausted, release the slot", {}); - - continue; - // END LOCALAI changes - } - } - } - - // decode any currently ongoing sequences - LOG_VERBOSE("decoding ongoing sequences", {}); - for (auto & slot : slots) - { - // release the slot - if (slot.command == RELEASE) - { - slot.state = IDLE; - slot.command = NONE; - slot.t_last_used = ggml_time_us(); - - LOG_INFO("slot released", { - {"slot_id", slot.id}, - {"task_id", slot.task_id}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}, - {"truncated", slot.truncated} + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { + result.probs.push_back({ + cur[i].id, + common_token_to_piece(ctx, cur[i].id, special), + cur[i].p }); - queue_tasks.notify_slot_changed(); + } + } + } + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); + } + + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(slot.id_task, error, type); + } + + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); + + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; + + queue_results.send(std::move(res)); + } + + // if multimodal is enabled, send an error and return false + bool ensure_no_mtmd(const int id_task) { + if (mctx) { + send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; + } + + void send_partial_response(server_slot & slot, const completion_token_output & tkn) { + auto res = std::make_unique(); + + res->id = slot.id_task; + res->index = slot.index; + res->content = tkn.text_to_send; + res->tokens = { tkn.tok }; + + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs + } + + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + res->timings = slot.get_timings(); + } + + queue_results.send(std::move(res)); + } + + void send_final_response(server_slot & slot) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->id_slot = slot.id; + + res->index = slot.index; + res->content = std::move(slot.generated_text); + res->tokens = std::move(slot.generated_tokens); + res->timings = slot.get_timings(); + res->prompt = slot.prompt_tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.params.response_fields); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_tokens_cached = slot.n_past; + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->oaicompat_chat_format = slot.params.oaicompat_chat_format; + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + + size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - safe_offset); + } else { + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); + } + } + + res->generation_params = slot.params; // copy the parameters + + queue_results.send(std::move(res)); + } + + void send_embedding(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + res->oaicompat = slot.params.oaicompat; + + const int n_embd = llama_model_n_embd(model); + + std::vector embd_res(n_embd, 0.0f); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { continue; } - if (slot.state == IDLE) - { + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->embedding.push_back(std::vector(n_embd, 0.0f)); + continue; + } + + // normalize only when there is pooling + // TODO: configurable + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + res->embedding.push_back({ embd, embd + n_embd }); + } + } + + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); + } + + void send_rerank(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->score = -1e6; + continue; + } + + res->score = embd[0]; + } + + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); + + queue_results.send(std::move(res)); + } + + // + // Functions to create new task(s) and receive result(s) + // + + void cancel_tasks(const std::unordered_set & id_tasks) { + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto & id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(std::move(task)); + } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(std::move(cancel_tasks), true); + } + + // receive the results from task(s) + void receive_multi_results( + const std::unordered_set & id_tasks, + const std::function&)> & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + std::vector results(id_tasks.size()); + for (int i = 0; i < (int)id_tasks.size(); i++) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + i--; // retry + continue; + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + const size_t idx = result->get_index(); + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = std::move(result); + } + result_handler(results); + } + + // receive the results from task(s), in stream mode + void receive_cmpl_results_stream( + const std::unordered_set & id_tasks, + const std::function & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + size_t n_finished = 0; + while (true) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + continue; // retry + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (!result_handler(result)) { + cancel_tasks(id_tasks); + break; + } + + if (result->is_stop()) { + if (++n_finished == id_tasks.size()) { + break; + } + } + } + } + + // + // Functions to process the task + // + + void process_single_task(server_task && task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + { + const int id_slot = task.id_selected_slot; + + server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); + + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (!launch_slot_with_task(*slot, std::move(task))) { + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.id_task == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); + + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot & slot : slots) { + json slot_data = slot.to_json(); + + if (slot.is_processing()) { + n_processing_slots++; + } else { + n_idle_slots++; + } + + slots_data.push_back(slot_data); + } + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); + + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->t_start = metrics.t_start; + + res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx); + res->kv_cache_used_cells = llama_kv_self_used_cells(ctx); + + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; + + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; + + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; + + if (task.metrics_reset_bucket) { + metrics.reset_bucket(); + } + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + if (!ensure_no_mtmd(task.id)) { + break; + } + + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + if (!ensure_no_mtmd(task.id)) break; + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + llama_tokens tokens; + tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.clear(); // KV may already been invalidated? + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + tokens.resize(token_count); + slot->cache_tokens.clear(); + slot->cache_tokens.insert(tokens); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + if (!ensure_no_mtmd(task.id)) break; + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_kv_self_seq_rm(ctx, slot->id, -1, -1); + slot->cache_tokens.clear(); + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SET_LORA: + { + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); + } break; + + } + } + + void update_slots() { + // check if all slots are idle + { + bool all_idle = true; + + for (auto & slot : slots) { + if (slot.is_processing()) { + all_idle = false; + break; + } + } + + if (all_idle) { + SRV_INF("%s", "all slots are idle\n"); + if (clean_kv_cache) { + kv_cache_clear(); + } + + return; + } + } + + { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); + } + + // apply context-shift if needed + // TODO: simplify and improve + for (server_slot & slot : slots) { + if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in process_token() + slot.release(); + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + continue; + } + + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + + // Shift context + const int n_keep = slot.params.n_keep + add_bos_token; + const int n_left = slot.n_past - n_keep; + const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); + + llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); + llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); + + if (slot.params.cache_prompt) { + llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { + new_tokens[i - n_discard] = new_tokens[i]; + } + + new_tokens.resize(slot.cache_tokens.size() - n_discard); + slot.cache_tokens.clear(); + slot.cache_tokens.insert(new_tokens); + } + + slot.n_past -= n_discard; + + slot.truncated = true; + } + } + + // start populating the batch for this iteration + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); + }; + + // frist, add sampled tokens from any ongoing sequences + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { continue; } slot.i_batch = batch.n_tokens; - const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; + common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); - // TODO: we always have to take into account the "system_tokens" - // this is not great and needs to be improved somehow - common_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); slot.n_past += 1; + + if (slot.params.cache_prompt) { + slot.cache_tokens.push_back(slot.sampled); + } + + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); } // process in chunks of params.n_batch - int32_t n_batch = params.n_batch; + int32_t n_batch = llama_n_batch(ctx); + int32_t n_ubatch = llama_n_ubatch(ctx); - // assign workload to the slots - if (params.cont_batching || batch.n_tokens == 0) - { - for (auto & slot : slots) - { - const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get().empty()); - - // empty prompt passed -> release the slot and send empty response - // note: infill mode allows empty prompt - if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill) - { - slot.release(); - slot.print_timings(); - send_final_response(slot); - continue; + // next, batch any pending prompts without exceeding n_batch + if (params_base.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + // check if we can batch this slot with the previous one + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } } - // need process the prompt - if (slot.state == IDLE && slot.command == LOAD_PROMPT) - { - slot.state = PROCESSING; - slot.command = NONE; - std::vector prompt_tokens; - slot.t_start_process_prompt = ggml_time_us(); - slot.t_start_genereration = 0; + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + auto & prompt_tokens = slot.prompt_tokens; - if (slot.infill) - { - bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) - { - params.input_suffix.erase(0, 1); - suff_rm_leading_spc = false; - } - auto prefix_tokens = tokenize(slot.data, slot.params.input_prefix, false); - auto suffix_tokens = tokenize(slot.data, slot.params.input_suffix, false); - - const int space_token = 29871; // TODO: this should not be hardcoded - if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0][0] == space_token) { - suffix_tokens.erase(suffix_tokens.begin()); - } - - // Create llama_tokens vectors for the special tokens - llama_tokens fim_pre_tokens; - fim_pre_tokens.push_back(llama_vocab_fim_pre(vocab)); - llama_tokens bos_tokens; - bos_tokens.push_back(llama_vocab_bos(vocab)); - llama_tokens fim_suf_tokens; - fim_suf_tokens.push_back(llama_vocab_fim_suf(vocab)); - llama_tokens fim_mid_tokens; - fim_mid_tokens.push_back(llama_vocab_fim_mid(vocab)); - - // Create server_tokens objects - server_tokens fim_pre_token(fim_pre_tokens, mctx != nullptr); - server_tokens bos_token(bos_tokens, mctx != nullptr); - server_tokens fim_suf_token(fim_suf_tokens, mctx != nullptr); - server_tokens fim_mid_token(fim_mid_tokens, mctx != nullptr); - - // Insert tokens in the correct order - prefix_tokens.insert(prefix_tokens.begin(), fim_pre_token); - prefix_tokens.insert(prefix_tokens.begin(), bos_token); // always add BOS - prefix_tokens.insert(prefix_tokens.end(), fim_suf_token); - prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); - prefix_tokens.push_back(fim_mid_token); - prompt_tokens = prefix_tokens; - } - else - { - prompt_tokens = tokenize(slot.data, slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt - } - - slot.num_prompt_tokens = prompt_tokens.size(); - - if (slot.params.n_keep < 0) - { - slot.params.n_keep = slot.num_prompt_tokens; - } - slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - - // if input prompt is too big, truncate it - if (slot.num_prompt_tokens >= slot.n_ctx) - { - const int n_left = slot.n_ctx - slot.params.n_keep; - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep); - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end()); - - LOG_VERBOSE("input truncated", { - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); - slot.truncated = true; - - // Convert new_tokens to server_tokens - std::vector new_prompt_tokens; - server_tokens new_server_tokens(new_tokens, mctx != nullptr); - new_prompt_tokens.push_back(std::move(new_server_tokens)); - prompt_tokens = std::move(new_prompt_tokens); - - slot.num_prompt_tokens = prompt_tokens.size(); - GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx); - } - - if (!slot.params.cache_prompt) - { - common_sampler_reset(slot.ctx_sampling); + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_generation = 0; slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - slot.num_prompt_tokens_processed = slot.num_prompt_tokens; - } - else - { - // push the prompt into the sampling context (do not apply grammar) - for (auto &token : prompt_tokens) - { - // Convert server_tokens to llama_token for sampling - llama_token tok = token[0]; // Get first token - common_sampler_accept(slot.ctx_sampling, tok, false); - } + slot.n_prompt_tokens = prompt_tokens.size(); + slot.state = SLOT_STATE_PROCESSING_PROMPT; - // Convert server_tokens to llama_tokens for comparison - std::vector prompt_llama_tokens; - for (const auto &token : prompt_tokens) { - prompt_llama_tokens.push_back(token[0]); - } - slot.n_past = common_part(slot.cache_tokens, prompt_llama_tokens); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); - // the last token of the cache is not in the KV cache until the next call to llama_decode - // (it was sampled, pushed into the "cache_tokens", but not yet put in the context) - if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size()) - { - slot.n_past -= 1; - } - - slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past; - - if (slot.ga_n != 1) - { - int ga_i = 0; - int32_t ga_n = slot.ga_n; - int32_t ga_w = slot.ga_w; - int32_t slot_npast = 0; - for (int k = 0; k < slot.n_past; ++k) - { - while (slot_npast >= ga_i + ga_w) { - const int bd = (ga_w/ga_n)*(ga_n - 1); - slot_npast -= bd; - ga_i += ga_w/ga_n; - } - slot_npast++; + // print prompt tokens (for debugging) + /*if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } - slot.n_past_se = slot_npast; - slot.ga_i = ga_i; + } else { + // all + for (int i = 0; i < (int) prompt_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + }*/ + + // empty prompt passed -> release the slot and send empty response + if (prompt_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); + + slot.release(); + slot.print_timings(); + send_final_response(slot); + continue; } - LOG_INFO("slot progression", { - { "slot_id", slot.id }, - { "task_id", slot.task_id }, - { "n_past", slot.n_past }, - { "num_prompt_tokens_processed", slot.num_prompt_tokens_processed } - }); + if (slot.is_non_causal()) { + if (slot.n_prompt_tokens > n_ubatch) { + slot.release(); + send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + continue; + } + + if (slot.n_prompt_tokens > slot.n_ctx) { + slot.release(); + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); + continue; + } + } else { + if (!params_base.ctx_shift) { + // if context shift is disabled, we make sure prompt size is smaller than KV size + // TODO: there should be a separate parameter that control prompt truncation + // context shift should be applied only during the generation phase + if (slot.n_prompt_tokens >= slot.n_ctx) { + slot.release(); + send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); + continue; + } + } + if (slot.params.n_keep < 0) { + slot.params.n_keep = slot.n_prompt_tokens; + } + slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); + + // if input prompt is too big, truncate it + if (slot.n_prompt_tokens >= slot.n_ctx) { + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + const int n_left = slot.n_ctx - slot.params.n_keep; + + const int n_block_size = n_left / 2; + const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + + const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); + llama_tokens new_tokens( + curr_tokens.begin(), + curr_tokens.begin() + slot.params.n_keep); + + new_tokens.insert( + new_tokens.end(), + curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + curr_tokens.end()); + + prompt_tokens.clear(); + prompt_tokens.insert(new_tokens); + + slot.truncated = true; + slot.n_prompt_tokens = prompt_tokens.size(); + + SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); + + GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + } + + if (slot.params.cache_prompt) { + // reuse any previously computed tokens that are common with the new prompt + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + size_t head_c = slot.n_past; // cache + size_t head_p = slot.n_past; // current prompt + + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); + + while (head_c < slot.cache_tokens.size() && + head_p < prompt_tokens.size()) { + + size_t n_match = 0; + while (head_c + n_match < slot.cache_tokens.size() && + head_p + n_match < prompt_tokens.size() && + slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t) params_base.n_cache_reuse) { + SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); + //for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + //} + + const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; + + llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c); + llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); + + for (size_t i = 0; i < n_match; i++) { + slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); + slot.n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); + } + } + } + + if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { + // we have to evaluate at least 1 token to generate logits. + SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); + + slot.n_past--; + } + + slot.n_prompt_tokens_processed = 0; } - // Convert server_tokens to llama_tokens for cache - std::vector cache_llama_tokens; - for (const auto &token : prompt_tokens) { - cache_llama_tokens.push_back(token[0]); - } - slot.cache_tokens = cache_llama_tokens; - - if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0) - { - // we have to evaluate at least 1 token to generate logits. - LOG_INFO("we have to evaluate at least 1 token to generate logits", { - { "slot_id", slot.id }, - { "task_id", slot.task_id } - }); - slot.n_past--; - if (slot.ga_i > 0) - { - slot.n_past_se--; + // non-causal tasks require to fit the entire prompt in the physical batch + if (slot.is_non_causal()) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + continue; } } - int p0 = (int) system_tokens.size() + slot.n_past; - LOG_INFO("kv cache rm [p0, end)", { - { "slot_id", slot.id }, - { "task_id", slot.task_id }, - { "p0", p0 } - }); - llama_kv_cache_seq_rm(ctx, slot.id, p0, -1); + // keep only the common part + if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) { + // could not partially delete (likely using a non-Transformer model) + llama_kv_self_seq_rm(ctx, slot.id, -1, -1); + // there is no common part left + slot.n_past = 0; + } - // process the prefix of first image - std::vector prefix_tokens = prompt_tokens; + SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); - int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; + // remove the non-common part from the cache + slot.cache_tokens.resize(slot.n_past); // check if we should process the image if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { // process the image int32_t new_n_past; - int32_t res = prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); int32_t n_pos = new_n_past - slot.n_past; + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); slot.release(); - LOG_ERR("failed to process image, res = %d\n", res); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); continue; } + if (slot.params.cache_prompt) { + const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); + slot.cache_tokens.push_back(chunk.get()); // copy + } slot.n_past += n_pos; - // slot.n_prompt_tokens_processed += n_pos; + slot.n_prompt_tokens_processed += n_pos; } - LOG_VERBOSE("prompt ingested", { - {"n_past", slot.n_past}, - {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, - {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, - }); + // add prompt tokens for processing in the current batch + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // get next token to process + llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + // without pooling, we want to output the embeddings for all the tokens in the batch + const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - int32_t ga_i = slot.ga_i; - int32_t ga_n = slot.ga_n; - int32_t ga_w = slot.ga_w; + common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); + if (slot.params.cache_prompt) { + slot.cache_tokens.push_back(cur_tok); + } - for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) - { - if (slot.ga_n != 1) - { - while (slot_npast >= ga_i + ga_w) { - const int bd = (ga_w/ga_n)*(ga_n - 1); - slot_npast -= bd; - ga_i += ga_w/ga_n; + slot.n_prompt_tokens_processed++; + slot.n_past++; + } + + // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + + // entire prompt has been processed + if (slot.n_past == slot.n_prompt_tokens) { + slot.state = SLOT_STATE_DONE_PROMPT; + + GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); + + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + llama_token id = slot.prompt_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, id, false); } } - common_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false); - slot_npast++; - } - // extract the logits only for the last token - if (batch.n_tokens > 0) - { + // extract the logits only for the last token batch.logits[batch.n_tokens - 1] = true; - } - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + } + } + + if (batch.n_tokens >= n_batch) { + break; } } } - if (batch.n_tokens == 0) - { - all_slots_are_idle = true; - return true; + if (batch.n_tokens == 0) { + SRV_WRN("%s", "no tokens to decode\n"); + return; } - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) - { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - for (auto & slot : slots) - { - if (slot.ga_n != 1) - { - // context extension via Self-Extend - while (slot.n_past_se >= slot.ga_i + slot.ga_w) - { - const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; - const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); - const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; + if (slot_batched) { + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, slot_batched->is_non_causal()); + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + } - LOG("\n"); - LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd); - LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); - LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); + // process the created batch of tokens + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd); - - slot.n_past_se -= bd; - - slot.ga_i += slot.ga_w / slot.ga_n; - - LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i); - } - slot.n_past_se += n_tokens; - } - } - - llama_batch batch_view = - { + llama_batch batch_view = { n_tokens, batch.token + i, nullptr, @@ -1851,135 +3319,237 @@ struct llama_server_context batch.logits + i, }; - const int ret = llama_decode(ctx, batch_view); + int ret = 0; - if (ret != 0) - { - if (n_batch == 1 || ret < 0) - { + if (params_base.embedding || params_base.reranking) { + ret = llama_encode(ctx, batch_view); + } else { + ret = llama_decode(ctx, batch_view); + } + + metrics.on_decoded(slots); + + if (ret != 0) { + if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size - LOG("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); - return false; + SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + for (auto & slot : slots) { + slot.release(); + send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); + } + break; // break loop of n_batch } - LOG("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2); - // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; i -= n_batch; - continue; + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + + continue; // continue loop of n_batch } - for (auto & slot : slots) - { - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) - { - continue; + for (auto & slot : slots) { + if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { + continue; // continue loop of slots } - // prompt evaluated for embedding - if (slot.embedding) - { - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + } else if (slot.state != SLOT_STATE_GENERATING) { + continue; // continue loop of slots } - completion_token_output result; - const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i); + const int tok_idx = slot.i_batch - i; - common_sampler_accept(slot.ctx_sampling, id, true); + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl, id, true); slot.n_decoded += 1; - if (slot.n_decoded == 1) - { - slot.t_start_genereration = ggml_time_us(); - slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3; + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_current; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; metrics.on_prompt_eval(slot); } - result.tok = id; - const auto * cur_p = common_sampler_get_candidates(slot.ctx_sampling); + slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; - for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { - result.probs.push_back({ - cur_p->data[i].id, - i >= cur_p->size ? 0.0f : cur_p->data[i].p, - }); + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if (slot.params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); } - if (!process_token(result, slot)) - { + if (!process_token(result, slot)) { + // release slot because of stop condition slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + continue; + } + } + + // do speculative decoding + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; } - slot.i_batch = -1; + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + + // keep track of total number of tokens generated in the draft + slot.n_draft_total += draft.size(); + + // ignore small drafts + if (slot.params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + + continue; + } + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + // update how many tokens out of draft was accepted + slot.n_draft_accepted += ids.size() - 1; + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); + + llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); } } - LOG_VERBOSE("slots updated", {}); - return true; + SRV_DBG("%s", "run slots completed\n"); } - void run_on_all_tasks_finished() { - update_slots(); + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, + }; } }; -/* llama.cpp completion api semantics */ -static json format_partial_response( - llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector &probs -) { - json res = json - { - {"content", content }, - {"stop", false}, - {"slot_id", slot->id }, - {"multimodal", llama.multimodal } - }; - - if (slot->sparams.n_probs > 0) - { - res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); - } - - return res; -} - -struct token_translator -{ - llama_context * ctx; - std::string operator()(llama_token tok) const { return common_token_to_piece(ctx, tok); } - std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); } -}; - -static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot *slot) -{ - auto & gtps = slot->generated_token_probs; - auto translator = token_translator{llama.ctx}; - auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); }; - const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen); - if (slot->generated_text.capacity() < slot->generated_text.size() + len) - { - slot->generated_text.reserve(slot->generated_text.size() + len); - } - for (const completion_token_output & cto : gtps) - { - slot->generated_text += translator(cto); - } -} std::function shutdown_handler; +std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; inline void signal_handler(int signal) { - exit(1); + if (is_terminating.test_and_set()) { + // in case it hangs, we can force terminate the server by hitting Ctrl+C twice + // this is for better developer experience, we can remove when the server is stable enough + fprintf(stderr, "Received second interrupt, terminating immediately.\n"); + exit(1); + } + + shutdown_handler(signal); } + ///////////////////////////////// //////////////////////////////// //////// LOCALAI code starts below here @@ -1991,29 +3561,56 @@ bool loaded_model; // TODO: add a mutex for this, but happens only once loading // The class has a llama instance that is shared across all RPCs llama_server_context llama; -static void start_llama_server() { +static void start_llama_server(server_context& ctx_server) { // Wait for model to be loaded first while (!loaded_model) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); } - llama.queue_tasks.on_new_task(std::bind( - &llama_server_context::process_single_task, &llama, std::placeholders::_1)); - llama.queue_tasks.on_finish_multitask(std::bind( - &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1)); - llama.queue_tasks.on_all_tasks_finished(std::bind( - &llama_server_context::run_on_all_tasks_finished, &llama)); - llama.queue_results.on_multitask_update(std::bind( - &llama_server_queue::update_multitask, - &llama.queue_tasks, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3 - )); - llama.queue_tasks.start_loop(); + ctx_server.init(); + state.store(SERVER_STATE_READY); + + LOG_INF("%s: model loaded\n", __func__); + + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(ctx_server.chat_templates.get()), + common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str()); + + ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { + ctx_server.process_single_task(std::move(task)); + }); + + ctx_server.queue_tasks.on_update_slots([&ctx_server]() { + ctx_server.update_slots(); + }); + + shutdown_handler = [&](int) { + // this will unblock start_loop() + ctx_server.queue_tasks.terminate(); + }; + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port); + + // this call blocks the main thread until queue_tasks.terminate() is called + ctx_server.queue_tasks.start_loop(); } -json parse_options(bool streaming, const backend::PredictOptions* predict, llama_server_context &llama) +json parse_options(bool streaming, const backend::PredictOptions* predict) { // Create now a json data from the prediction options instead @@ -2209,13 +3806,13 @@ static void params_parse(const backend::ModelOptions* request, if (request->grammartriggers_size() > 0) { LOG_INFO("configuring grammar triggers", {}); - llama.grammar_lazy = true; + params.sampling.grammar_lazy = true; for (int i = 0; i < request->grammartriggers_size(); i++) { common_grammar_trigger trigger; trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; trigger.value = request->grammartriggers(i).word(); // trigger.at_start = request->grammartriggers(i).at_start(); - llama.grammar_triggers.push_back(trigger); + params.sampling.grammar_triggers.push_back(trigger); LOG_INFO("grammar trigger", { { "word", trigger.value }, }); @@ -2226,189 +3823,475 @@ static void params_parse(const backend::ModelOptions* request, // GRPC Server start class BackendServiceImpl final : public backend::Backend::Service { +private: + server_context& ctx_server; + bool loaded_model = false; + public: - grpc::Status Health(ServerContext* context, const backend::HealthMessage* request, backend::Reply* reply) { - // Implement Health RPC - reply->set_message("OK"); - return Status::OK; - } + BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {} - grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) { - // Implement LoadModel RPC - common_params params; - params_parse(request, params, llama); - - llama_backend_init(); - llama_numa_init(params.numa); - - // load the model - if (!llama.load_model(params)) - { - result->set_message("Failed loading model"); - result->set_success(false); - return Status::CANCELLED; + grpc::Status Health(ServerContext* context, const backend::HealthMessage* request, backend::Reply* reply) { + // Implement Health RPC + reply->set_message("OK"); + return Status::OK; } - llama.initialize(); - result->set_message("Loading succeeded"); - result->set_success(true); - loaded_model = true; - return Status::OK; - } - grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter* writer) override { - json data = parse_options(true, request, llama); - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, data, false, false, -1); - while (true) - { - task_result result = llama.queue_results.recv(task_id); - if (!result.error) { - const std::string str = - "data: " + - result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", { - { "to_send", str } - }); + + grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) { + // Implement LoadModel RPC + common_params params; + params_parse(request, params); + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + // load the model + if (!ctx_server.load_model(params)) { + result->set_message("Failed loading model"); + result->set_success(false); + return Status::CANCELLED; + } + + //ctx_server.init(); + result->set_message("Loading succeeded"); + result->set_success(true); + loaded_model = true; + return Status::OK; + } + + grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter* writer) override { + json data = parse_options(true, request); + + + //Raise error if embeddings is set to true + if (ctx_server.params_base.embedding) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode"); + } + + + auto completion_id = gen_chatcmplid(); + std::unordered_set task_ids; + try { + std::vector tasks; + + const auto & prompt = data.at("prompt"); + // TODO: this log can become very long, put it behind a flag or think about a more compact format + //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + + // process files + mtmd::bitmaps bitmaps; + const bool has_mtmd = ctx_server.mctx != nullptr; + { + if (!has_mtmd && !files.empty()) { + throw std::runtime_error("This server does not support multimodal"); + } + for (auto & file : files) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image"); + } + // calculate bitmap hash (for KV caching) + std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); + bmp.set_id(hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } + } + + // process prompt + std::vector inputs; + if (!prompt.is_string()) { + throw std::runtime_error("prompt must be a string"); + } + + if (has_mtmd) { + // multimodal + std::string prompt_str = prompt.get(); + mtmd_input_text inp_txt = { + prompt_str.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); + } + + server_tokens tmp(chunks, true); + inputs.push_back(std::move(tmp)); + } else { + // non-multimodal version + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (auto & p : tokenized_prompts) { + auto tmp = server_tokens(p, ctx_server.mctx != nullptr); + inputs.push_back(std::move(tmp)); + } + } + + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = OAICOMPAT_TYPE_COMPLETION; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(std::move(task)); + } + + task_ids = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + } catch (const std::exception & e) { + res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { + json res_json = result->to_json(); + if (res_json.is_array()) { + for (const auto & res : res_json) { + std::string completion_text = res.value("content", ""); + + backend::Reply reply; + reply.set_message(completion_text); + int32_t tokens_predicted = res.value("tokens_predicted", 0); + reply->set_tokens(tokens_predicted); + int32_t tokens_evaluated = res.value("tokens_evaluated", 0); + reply->set_prompt_tokens(tokens_evaluated); + + if (res.contains("timings")) { + double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0); + reply.set_timing_prompt_processing(timing_prompt_processing); + double timing_token_generation = res.at("timings").value("predicted_ms", 0.0); + reply.set_timing_token_generation(timing_token_generation); + } + + // Log Request Correlation Id + LOG_VERBOSE("correlation:", { + { "id", data["correlation_id"] } + }); + + // Send the reply + writer->Write(reply); + } + } else { + std::string completion_text = res_json.value("content", ""); backend::Reply reply; - // print it - std::string completion_text = result.result_json.value("content", ""); - reply.set_message(completion_text); - int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0); + int32_t tokens_predicted = res_json.value("tokens_predicted", 0); reply.set_tokens(tokens_predicted); - int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0); + int32_t tokens_evaluated = res_json.value("tokens_evaluated", 0); reply.set_prompt_tokens(tokens_evaluated); - if (result.result_json.contains("timings")) { - double timing_prompt_processing = result.result_json.at("timings").value("prompt_ms", 0.0); + if (res_json.contains("timings")) { + double timing_prompt_processing = res_json.at("timings").value("prompt_ms", 0.0); reply.set_timing_prompt_processing(timing_prompt_processing); - double timing_token_generation = result.result_json.at("timings").value("predicted_ms", 0.0); + double timing_token_generation = res_json.at("timings").value("predicted_ms", 0.0); reply.set_timing_token_generation(timing_token_generation); } - + // Log Request Correlation Id LOG_VERBOSE("correlation:", { { "id", data["correlation_id"] } }); // Send the reply - writer->Write(reply); - - if (result.stop) { - break; - } - } else { - break; + writer->Write(reply); + } - } + return true; + }, [&](const json & error_data) { + backend::Reply reply; + reply.set_message(error_data.value("content", "")); + writer->Write(reply); + return true; + }, [&writer]() { + // note: do not use req.is_connection_closed here because req is already destroyed + return !writer->IsWritePossible(); + }); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); return grpc::Status::OK; } - grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) { - json data = parse_options(false, request, llama); - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, data, false, false, -1); - std::string completion_text; - task_result result = llama.queue_results.recv(task_id); - if (!result.error && result.stop) { - - // Log Request Correlation Id - LOG_VERBOSE("correlation:", { - { "id", data["correlation_id"] } - }); + + //Raise error if embeddings is set to true + if (ctx_server.params_base.embedding) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode"); + } - completion_text = result.result_json.value("content", ""); - int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0); - int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0); - reply->set_prompt_tokens(tokens_evaluated); - reply->set_tokens(tokens_predicted); - reply->set_message(completion_text); + auto completion_id = gen_chatcmplid(); + std::unordered_set task_ids; + try { + std::vector tasks; - if (result.result_json.contains("timings")) { - double timing_prompt_processing = result.result_json.at("timings").value("prompt_ms", 0.0); - reply->set_timing_prompt_processing(timing_prompt_processing); - double timing_token_generation = result.result_json.at("timings").value("predicted_ms", 0.0); - reply->set_timing_token_generation(timing_token_generation); + const auto & prompt = data.at("prompt"); + // TODO: this log can become very long, put it behind a flag or think about a more compact format + //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + + // process files + mtmd::bitmaps bitmaps; + const bool has_mtmd = ctx_server.mctx != nullptr; + { + if (!has_mtmd && !files.empty()) { + throw std::runtime_error("This server does not support multimodal"); + } + for (auto & file : files) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image"); + } + // calculate bitmap hash (for KV caching) + std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); + bmp.set_id(hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } } + + // process prompt + std::vector inputs; + if (!prompt.is_string()) { + throw std::runtime_error("prompt must be a string"); + } + + if (has_mtmd) { + // multimodal + std::string prompt_str = prompt.get(); + mtmd_input_text inp_txt = { + prompt_str.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); + } + + server_tokens tmp(chunks, true); + inputs.push_back(std::move(tmp)); + } else { + // non-multimodal version + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (auto & p : tokenized_prompts) { + auto tmp = server_tokens(p, ctx_server.mctx != nullptr); + inputs.push_back(std::move(tmp)); + } + } + + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = OAICOMPAT_TYPE_COMPLETION; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(std::move(task)); + } + + task_ids = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + } catch (const std::exception & e) { + res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; } - else - { - return grpc::Status::OK; + + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + if (results.size() == 1) { + // single result + reply->set_message(results[0]->to_json()); + } else { + // multiple results (multitask) + json arr = json::array(); + for (auto & res : results) { + arr.push_back(res->to_json()); + } + reply->set_message(arr); } + }, [&](const json & error_data) { + reply->set_message(error_data.value("content", "")); + }, [&]() { + return false; + }); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); return grpc::Status::OK; } - /// https://github.com/ggerganov/llama.cpp/blob/aa2341298924ac89778252015efcb792f2df1e20/examples/server/server.cpp#L2969 grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) { - json data = parse_options(false, request, llama); - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, -1); - // get the result - task_result result = llama.queue_results.recv(task_id); - //std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl; - llama.queue_results.remove_waiting_task_id(task_id); - if (!result.error && result.stop) { - std::vector embeddings = result.result_json.value("embedding", std::vector()); - // loop the vector and set the embeddings results - for (int i = 0; i < embeddings.size(); i++) { - embeddingResult->add_embeddings(embeddings[i]); + + json body = parse_options(false, request); + + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Pooling type 'none' is not OAI compatible. Please use a different pooling type"); + } + + // for the shape of input/content, see tokenize_input_prompts() + json prompt = body.at("prompt"); + + + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (const auto & tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Input content cannot be empty"); } } - else + + // create and queue the task + json responses = json::array(); + bool error = false; + std::unordered_set task_ids; { - return grpc::Status::OK; + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr); + + // OAI-compat + task.params.oaicompat = OAICOMPAT_TYPE_EMBEDDING; + + tasks.push_back(std::move(task)); + } + + task_ids = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + } + + // get the result + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + for (auto & res : results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + }, [&](const json & error_data) { + res_error(res, error_data); + error = true; + }, req.is_connection_closed); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + + if (error) { + return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results"); + } + + std::vector embeddings = responses[0].value("embedding", std::vector()); + // loop the vector and set the embeddings results + for (int i = 0; i < embeddings.size(); i++) { + embeddingResult->add_embeddings(embeddings[i]); } return grpc::Status::OK; } - grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response){ - json data = parse_options(false, request, llama); + grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) { + json body = parse_options(false, request); - std::vector tokens = llama.tokenize(data, data["prompt"],false); + json tokens_response = json::array(); + if (body.count("prompt") != 0) { + const bool add_special = json_value(body, "add_special", false); + const bool with_pieces = json_value(body, "with_pieces", false); - for (int i=0 ; i< tokens.size(); i++){ - response->add_tokens(tokens[i].llama_token); - } + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true); + + + for (const auto& token : tokens) { + std::string piece = common_token_to_piece(ctx_server.ctx, token); + response->add_tokens(token); + } + } return grpc::Status::OK; } grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) { - llama_client_slot* active_slot = llama.get_active_slot(); - if (active_slot != nullptr) { - // Calculate the tokens per second using existing logic - double tokens_per_second = 1e3 / active_slot->t_token_generation * active_slot->n_decoded; +// request slots data using task queue + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_METRICS); + task.id = task_id; + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + } - // Populate the response with metrics - response->set_slot_id(active_slot->id); - response->set_prompt_json_for_slot(active_slot->prompt.dump()); - response->set_tokens_per_second(tokens_per_second); - response->set_tokens_generated(active_slot->n_decoded); - response->set_prompt_tokens_processed(active_slot->num_prompt_tokens_processed); - } else { + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { // Handle case when no active slot exists response->set_slot_id(0); response->set_prompt_json_for_slot(""); response->set_tokens_per_second(0); response->set_tokens_generated(0); response->set_prompt_tokens_processed(0); + return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results"); } + // TODO: get rid of this dynamic_cast + auto res_metrics = dynamic_cast(result.get()); + GGML_ASSERT(res_metrics != nullptr); + + // Populate the response with metrics + response->set_slot_id(0); + response->set_prompt_json_for_slot(""); + response->set_tokens_per_second(res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.); + response->set_tokens_generated(res_metrics->n_tokens_predicted_total); + response->set_prompt_tokens_processed(res_metrics->n_prompt_tokens_processed_total); + + return grpc::Status::OK; - } + } }; -void RunServer(const std::string& server_address) { - BackendServiceImpl service; + +void RunServer(const std::string& server_address, server_context& ctx_server) { + BackendServiceImpl service(ctx_server); ServerBuilder builder; builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); @@ -2424,20 +4307,6 @@ void RunServer(const std::string& server_address) { int main(int argc, char** argv) { std::string server_address("localhost:50051"); -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = signal_handler; - sigemptyset (&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); - sigaction(SIGTERM, &sigint_action, NULL); -#elif defined (_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif - // Define long and short options struct option long_options[] = { {"addr", required_argument, nullptr, 'a'}, @@ -2457,21 +4326,24 @@ int main(int argc, char** argv) { return 1; } } + + server_context ctx_server; // run the HTTP server in a thread - see comment below std::thread t([&]() - { - RunServer(server_address); - return 0; - }); + { + RunServer(server_address, ctx_server); + return 0; + }); //); - start_llama_server(); - std::cout << "stopping" << std::endl; + start_llama_server(ctx_server); + std::cout << "stopping" << std::endl; + + clean_up(); t.join(); - llama_backend_free(); - return 0; + return 0; } diff --git a/backend/cpp/llama/utils.hpp b/backend/cpp/llama/utils.hpp index c466d356..58af0991 100644 --- a/backend/cpp/llama/utils.hpp +++ b/backend/cpp/llama/utils.hpp @@ -1,490 +1,90 @@ -// https://github.com/ggerganov/llama.cpp/blob/master/tools/server/utils.hpp - #pragma once +#include "common.h" +#include "log.h" +#include "llama.h" +#include "arg.h" // common_remote_get_content +#include "base64.hpp" +#include "mtmd.h" + + + +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "json.hpp" +#include "chat.h" + +#include +#include #include #include -#include -#include -#include -#include +#include +#include -#include "json.hpp" +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" -#include "../mtmd/clip.h" +using json = nlohmann::ordered_json; -using json = nlohmann::json; +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -extern bool server_verbose; +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#ifndef SERVER_VERBOSE -#define SERVER_VERBOSE 1 -#endif +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#if SERVER_VERBOSE != 1 -#define LOG_VERBOSE(MSG, ...) -#else -#define LOG_VERBOSE(MSG, ...) \ - do \ - { \ - if (server_verbose) \ - { \ - server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \ - } \ - } while (0) -#endif - -#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) - -// -// parallel -// - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded - SERVER_STATE_ERROR // An error occurred, load_model failed -}; - -enum task_type { - TASK_TYPE_COMPLETION, - TASK_TYPE_CANCEL, - TASK_TYPE_NEXT_RESPONSE -}; - -struct task_server { - int id = -1; // to be filled by llama_server_queue - int target_id; - task_type type; - json data; - bool infill_mode = false; - bool embedding_mode = false; - int multitask_id = -1; -}; - -struct task_result { - int id; - int multitask_id = -1; - bool stop; - bool error; - json result_json; -}; - -struct task_multi { - int id; - std::set subtasks_remaining{}; - std::vector results{}; -}; - -// TODO: can become bool if we can't find use of more states -enum slot_state -{ - IDLE, - PROCESSING, -}; - -enum slot_command -{ - NONE, - LOAD_PROMPT, - RELEASE, -}; - -struct slot_params -{ - bool stream = true; - bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt - - uint32_t seed = -1; // RNG seed - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_predict = -1; // new tokens to predict - - std::vector antiprompt; - - json input_prefix; - json input_suffix; -}; - -struct slot_image -{ - int32_t id; - - bool request_encode_image = false; - float * image_embedding = nullptr; - int32_t image_tokens = 0; - - clip_image_u8 * img_data; - - std::string prefix_prompt; // before of this image -}; - -// completion token output with probabilities -struct completion_token_output -{ - struct token_prob - { - llama_token tok; - float prob; - }; - - std::vector probs; - llama_token tok; - std::string text_to_send; -}; - -static inline void server_log(const char *level, const char *function, int line, - const char *message, const nlohmann::ordered_json &extra) -{ - nlohmann::ordered_json log - { - {"timestamp", time(nullptr)}, - {"level", level}, - {"function", function}, - {"line", line}, - {"message", message}, - }; - - if (!extra.empty()) - { - log.merge_patch(extra); - } - - const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace); - printf("%.*s\n", (int)str.size(), str.data()); - fflush(stdout); -} - -// -// server utils -// +using raw_buffer = std::vector; template -static T json_value(const json &body, const std::string &key, const T &default_value) -{ +static T json_value(const json & body, const std::string & key, const T & default_value) { // Fallback null to default value - return body.contains(key) && !body.at(key).is_null() - ? body.value(key, default_value) - : default_value; + if (body.contains(key) && !body.at(key).is_null()) { + try { + return body.at(key); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name()); + return default_value; + } + } else { + return default_value; + } } -inline std::string format_chatml(std::vector messages) -{ - std::ostringstream chatml_msgs; +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); - for (auto it = messages.begin(); it != messages.end(); ++it) { - chatml_msgs << "<|im_start|>" - << json_value(*it, "role", std::string("user")) << '\n'; - chatml_msgs << json_value(*it, "content", std::string("")) - << "<|im_end|>\n"; - } +// thin wrapper around common_grammar_trigger with (de)serialization functions +struct server_grammar_trigger { + common_grammar_trigger value; - chatml_msgs << "<|im_start|>assistant" << '\n'; - - return chatml_msgs.str(); -} - -// -// work queue utils -// - -struct llama_server_queue { - int id = 0; - std::mutex mutex_tasks; - // queues - std::vector queue_tasks; - std::vector queue_tasks_deferred; - std::vector queue_multitasks; - std::condition_variable condition_tasks; - // callback functions - std::function callback_new_task; - std::function callback_finish_multitask; - std::function callback_all_task_finished; - - // Add a new task to the end of the queue - int post(task_server task) { - std::unique_lock lock(mutex_tasks); - if (task.id == -1) { - task.id = id++; - } - queue_tasks.push_back(std::move(task)); - condition_tasks.notify_one(); - return task.id; - } - - // Add a new task, but defer until one slot is available - void defer(task_server task) { - std::unique_lock lock(mutex_tasks); - queue_tasks_deferred.push_back(std::move(task)); - } - - // Get the next id for creating anew task - int get_new_id() { - std::unique_lock lock(mutex_tasks); - return id++; - } - - // Register function to process a new task - void on_new_task(std::function callback) { - callback_new_task = callback; - } - - // Register function to process a multitask - void on_finish_multitask(std::function callback) { - callback_finish_multitask = callback; - } - - // Register the function to be called when the batch of tasks is finished - void on_all_tasks_finished(std::function callback) { - callback_all_task_finished = callback; - } - - // Call when the state of one slot is changed - void notify_slot_changed() { - // move deferred tasks back to main loop - std::unique_lock lock(mutex_tasks); - for (auto & task : queue_tasks_deferred) { - queue_tasks.push_back(std::move(task)); - } - queue_tasks_deferred.clear(); - } - - // Start the main loop. This call is blocking - [[noreturn]] - void start_loop() { - while (true) { - // new task arrived - LOG_VERBOSE("have new task", {}); - { - while (true) - { - std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) { - lock.unlock(); - break; - } - task_server task = queue_tasks.front(); - queue_tasks.erase(queue_tasks.begin()); - lock.unlock(); - LOG_VERBOSE("callback_new_task", {}); - callback_new_task(task); - } - LOG_VERBOSE("callback_all_task_finished", {}); - // process and update all the multitasks - auto queue_iterator = queue_multitasks.begin(); - while (queue_iterator != queue_multitasks.end()) - { - if (queue_iterator->subtasks_remaining.empty()) - { - // all subtasks done == multitask is done - task_multi current_multitask = *queue_iterator; - callback_finish_multitask(current_multitask); - // remove this multitask - queue_iterator = queue_multitasks.erase(queue_iterator); - } - else - { - ++queue_iterator; - } - } - // all tasks in the current loop is finished - callback_all_task_finished(); - } - LOG_VERBOSE("wait for new task", {}); - // wait for new task - { - std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) { - condition_tasks.wait(lock, [&]{ - return !queue_tasks.empty(); - }); - } - } + server_grammar_trigger() = default; + server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} + server_grammar_trigger(const json & in) { + value.type = (common_grammar_trigger_type) in.at("type").get(); + value.value = in.at("value").get(); + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + value.token = (llama_token) in.at("token").get(); } } - // - // functions to manage multitasks - // - - // add a multitask by specifying the id of all subtask (subtask is a task_server) - void add_multitask(int multitask_id, std::vector& sub_ids) - { - std::lock_guard lock(mutex_tasks); - task_multi multi; - multi.id = multitask_id; - std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); - queue_multitasks.push_back(multi); - } - - // updatethe remaining subtasks, while appending results to multitask - void update_multitask(int multitask_id, int subtask_id, task_result& result) - { - std::lock_guard lock(mutex_tasks); - for (auto& multitask : queue_multitasks) - { - if (multitask.id == multitask_id) - { - multitask.subtasks_remaining.erase(subtask_id); - multitask.results.push_back(result); - } + json to_json() const { + json out { + {"type", (int) value.type}, + {"value", value.value}, + }; + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + out["token"] = (int) value.token; } + return out; } }; -struct llama_server_response { - typedef std::function callback_multitask_t; - callback_multitask_t callback_update_multitask; - // for keeping track of all tasks waiting for the result - std::set waiting_task_ids; - // the main result queue - std::vector queue_results; - std::mutex mutex_results; - std::condition_variable condition_results; - - void add_waiting_task_id(int task_id) { - std::unique_lock lock(mutex_results); - waiting_task_ids.insert(task_id); - } - - void remove_waiting_task_id(int task_id) { - std::unique_lock lock(mutex_results); - waiting_task_ids.erase(task_id); - } - - // This function blocks the thread until there is a response for this task_id - task_result recv(int task_id) { - while (true) - { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - return !queue_results.empty(); - }); - LOG_VERBOSE("condition_results unblock", {}); - - for (int i = 0; i < (int) queue_results.size(); i++) - { - if (queue_results[i].id == task_id) - { - assert(queue_results[i].multitask_id == -1); - task_result res = queue_results[i]; - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // should never reach here - } - - // Register the function to update multitask - void on_multitask_update(callback_multitask_t callback) { - callback_update_multitask = callback; - } - - // Send a new result to a waiting task_id - void send(task_result result) { - std::unique_lock lock(mutex_results); - LOG_VERBOSE("send new result", {}); - for (auto& task_id : waiting_task_ids) { - // LOG_TEE("waiting task id %i \n", task_id); - // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result - if (result.multitask_id == task_id) - { - LOG_VERBOSE("callback_update_multitask", {}); - callback_update_multitask(task_id, result.id, result); - continue; - } - - if (result.id == task_id) - { - LOG_VERBOSE("queue_results.push_back", {}); - queue_results.push_back(result); - condition_results.notify_one(); - return; - } - } - } -}; - -// -// base64 utils (TODO: move to common in the future) -// - -static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - -static inline bool is_base64(uint8_t c) -{ - return (isalnum(c) || (c == '+') || (c == '/')); -} - -static inline std::vector base64_decode(const std::string & encoded_string) -{ - int i = 0; - int j = 0; - int in_ = 0; - - int in_len = encoded_string.size(); - - uint8_t char_array_4[4]; - uint8_t char_array_3[3]; - - std::vector ret; - - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) - { - char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) - { - for (i = 0; i <4; i++) - { - char_array_4[i] = base64_chars.find(char_array_4[i]); - } - - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (i = 0; (i < 3); i++) - { - ret.push_back(char_array_3[i]); - } - i = 0; - } - } - - if (i) - { - for (j = i; j <4; j++) - { - char_array_4[j] = 0; - } - - for (j = 0; j <4; j++) - { - char_array_4[j] = base64_chars.find(char_array_4[j]); - } - - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (j = 0; (j < i - 1); j++) - { - ret.push_back(char_array_3[j]); - } - } - - return ret; - -} - - - // // tokenizer and input processing utils // @@ -539,7 +139,6 @@ static json json_get_nested_values(const std::vector & paths, const return result; } - /** * this handles 2 cases: * - only string, example: "string" @@ -623,8 +222,781 @@ static std::vector tokenize_input_prompts(const llama_vocab * voca return result; } +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +static size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; + } + } + + // If no cut-off multi-byte character is found, return full length + return len; +} + +// +// template utils +// + +// format rerank task: [BOS]query[EOS][SEP]doc[EOS] +static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) { + llama_tokens result; + + result.reserve(doc.size() + query.size() + 4); + result.push_back(llama_vocab_bos(vocab)); + result.insert(result.end(), query.begin(), query.end()); + result.push_back(llama_vocab_eos(vocab)); + result.push_back(llama_vocab_sep(vocab)); + result.insert(result.end(), doc.begin(), doc.end()); + result.push_back(llama_vocab_eos(vocab)); + + return result; +} + +// format infill task +static llama_tokens format_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement + + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... + // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); + + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); + + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + } + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } else { + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); + + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + } + + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + } + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } + + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); + const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); + + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); + + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); + + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); + + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); + + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; + + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); + } + + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); + + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); + + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); + + return embd_inp; +} + +// +// base64 utils (TODO: move to common in the future) +// + +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +static inline bool is_base64(uint8_t c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} + +static inline raw_buffer base64_decode(const std::string & encoded_string) { + int i = 0; + int j = 0; + int in_ = 0; + + int in_len = encoded_string.size(); + + uint8_t char_array_4[4]; + uint8_t char_array_3[3]; + + raw_buffer ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) { + ret.push_back(char_array_3[i]); + } + + i = 0; + } + } + + if (i) { + for (j = i; j < 4; j++) { + char_array_4[j] = 0; + } + + for (j = 0; j < 4; j++) { + char_array_4[j] = base64_chars.find(char_array_4[j]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; j < i - 1; j++) { + ret.push_back(char_array_3[j]); + } + } + + return ret; +} + +// +// random string / id +// + +static std::string random_string() { + static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + + std::random_device rd; + std::mt19937 generator(rd()); + + std::string result(32, ' '); + + for (int i = 0; i < 32; ++i) { + result[i] = str[generator() % str.size()]; + } + + return result; +} + +static std::string gen_chatcmplid() { + return "chatcmpl-" + random_string(); +} + +static std::string gen_tool_call_id() { + return random_string(); +} + +// +// other common utils +// + +static bool ends_with(const std::string & str, const std::string & suffix) { + return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +} + +static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { + if (!text.empty() && !stop.empty()) { + const char text_last_char = text.back(); + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { + if (stop[char_index] == text_last_char) { + const std::string current_partial = stop.substr(0, char_index + 1); + if (ends_with(text, current_partial)) { + return text.size() - char_index - 1; + } + } + } + } + + return std::string::npos; +} + +// TODO: reuse llama_detokenize +template +static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { + std::string ret; + for (; begin != end; ++begin) { + ret += common_token_to_piece(ctx, *begin); + } + + return ret; +} + +// format incomplete utf-8 multibyte character for output +static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); + + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } + + return out; +} +// +// OAI utils +// + +static json oaicompat_completion_params_parse(const json & body) { + json llama_params; + + if (!body.contains("prompt")) { + throw std::runtime_error("\"prompt\" is required"); + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Handle "echo" field + if (json_value(body, "echo", false)) { + throw std::runtime_error("Only no echo is supported"); + } + + // Params supported by OAI but unsupported by llama.cpp + static const std::vector unsupported_params { "best_of", "suffix" }; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); + } + } + + // Copy remaining properties to llama_params + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +static json oaicompat_completion_params_parse( + const json & body, /* openai api json semantics */ + bool use_jinja, + common_reasoning_format reasoning_format, + const struct common_chat_templates * tmpls, + bool allow_non_text, + std::vector & out_files) +{ + json llama_params; + + auto tools = json_value(body, "tools", json()); + auto stream = json_value(body, "stream", false); + + if (tools.is_array() && !tools.empty()) { + if (stream) { + throw std::runtime_error("Cannot use tools with stream"); + } + if (!use_jinja) { + throw std::runtime_error("tools param requires --jinja flag"); + } + } + if (!use_jinja) { + if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) { + throw std::runtime_error("Unsupported param: tool_choice"); + } + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + auto json_schema = json_value(body, "json_schema", json()); + auto grammar = json_value(body, "grammar", std::string()); + if (!json_schema.is_null() && !grammar.empty()) { + throw std::runtime_error("Cannot use both json_schema and grammar"); + } + + // Handle "response_format" field + if (body.contains("response_format")) { + json response_format = json_value(body, "response_format", json::object()); + std::string response_type = json_value(response_format, "type", std::string()); + if (response_type == "json_object") { + json_schema = json_value(response_format, "schema", json::object()); + } else if (response_type == "json_schema") { + auto schema_wrapper = json_value(response_format, "json_schema", json::object()); + json_schema = json_value(schema_wrapper, "schema", json::object()); + } else if (!response_type.empty() && response_type != "text") { + throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); + } + } + + // get input files + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + json messages = body.at("messages"); + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array"); + } + for (auto & msg : messages) { + json & content = msg.at("content"); + if (content.is_string() || content.is_null()) { + continue; + } + + if (!content.is_array()) { + throw std::runtime_error("Expected 'content' to be a string or an array"); + } + + for (auto & p : content) { + std::string type = json_value(p, "type", std::string()); + json image_url = json_value(p, "image_url", json::object()); + if (type == "image_url") { + if (!allow_non_text) { + throw std::runtime_error("image input is not supported by this server"); + } + + std::string url = json_value(image_url, "url", std::string()); + if (string_starts_with(url, "http")) { + // download remote image + // TODO @ngxson : maybe make these params configurable + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", url.c_str()); + auto res = common_remote_get_content(url, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %ld bytes\n", res.second.size()); + raw_buffer data; + data.insert(data.end(), res.second.begin(), res.second.end()); + out_files.push_back(data); + } else { + throw std::runtime_error("Failed to download image"); + } + + } else { + // try to decode base64 image + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid image_url.url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("image_url.url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + } + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = MTMD_DEFAULT_IMAGE_MARKER; + p.erase("image_url"); + } + } + } + + common_chat_templates_inputs inputs; + inputs.messages = common_chat_msgs_parse_oaicompat(messages); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.grammar = grammar; + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); + inputs.use_jinja = use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + + // if the assistant message appears at the end of list, we do not add end-of-turn token + // for ex. this can be useful to modify the reasoning process in reasoning models + bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant"; + common_chat_msg last_message; + if (prefill_assistant_message) { + last_message = inputs.messages.back(); + inputs.messages.pop_back(); + + /* sanity check, max one assistant message at the end of the list */ + if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ + throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); + } + + inputs.extract_reasoning = false; + inputs.add_generation_prompt = true; + } + + // Apply chat template to the list of messages + auto chat_params = common_chat_templates_apply(tmpls, inputs); + + /* Append assistant prefilled message */ + if (prefill_assistant_message) { + chat_params.prompt += last_message.content; + } + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + if (!chat_params.grammar.empty()) { + llama_params["grammar"] = chat_params.grammar; + } + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + server_grammar_trigger ct(trigger); + grammar_triggers.push_back(ct.to_json()); + } + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Handle "logprobs" field + // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future + if (json_value(body, "logprobs", false)) { + llama_params["n_probs"] = json_value(body, "top_logprobs", 20); + } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { + throw std::runtime_error("top_logprobs requires logprobs to be set to true"); + } + + // Copy remaining properties to llama_params + // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. + // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto & elem : embeddings) { + json embedding_obj; + + if (use_base64) { + const auto& vec = json_value(elem, "embedding", json::array()).get>(); + const char* data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = { + {"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"} + }; + } else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + } + data.push_back(embedding_obj); + + n_tokens += json_value(elem, "tokens_evaluated", 0); + } + + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"data", data} + }; + + return res; +} + +static json format_response_rerank( + const json & request, + const json & ranks, + bool is_tei_format, + std::vector & texts) { + json res; + if (is_tei_format) { + // TEI response format + res = json::array(); + bool return_text = json_value(request, "return_text", false); + for (const auto & rank : ranks) { + int index = json_value(rank, "index", 0); + json elem = json{ + {"index", index}, + {"score", json_value(rank, "score", 0.0)}, + }; + if (return_text) { + elem["text"] = std::move(texts[index]); + } + res.push_back(elem); + } + } else { + // Jina response format + json results = json::array(); + int32_t n_tokens = 0; + for (const auto & rank : ranks) { + results.push_back(json{ + {"index", json_value(rank, "index", 0)}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + + n_tokens += json_value(rank, "tokens_evaluated", 0); + } + + res = json{ + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json{ + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"results", results} + }; + } + + return res; +} + +static bool is_valid_utf8(const std::string & str) { + const unsigned char* bytes = reinterpret_cast(str.data()); + const unsigned char* end = bytes + str.length(); + + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || + (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte + return false; + } + } + + return true; +} + +static json format_tokenizer_response(const json & tokens) { + return json { + {"tokens", tokens} + }; +} + +static json format_detokenized_response(const std::string & content) { + return json { + {"content", content} + }; +} + +static json format_logit_bias(const std::vector & logit_bias) { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); + } + return data; +} + +static std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +} + +static std::vector get_token_probabilities(llama_context * ctx, int idx) { + std::vector cur; + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + // sort tokens by logits + std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + + // apply softmax + float max_l = cur[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < cur.size(); ++i) { + float p = expf(cur[i].logit - max_l); + cur[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < cur.size(); ++i) { + cur[i].p /= cum_sum; + } + + return cur; +} + +static bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; +} + +// parse lora config from JSON request, returned a copy of lora_base with updated scale +static std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; +} // // utils for interacting with libmtmd @@ -875,7 +1247,7 @@ public: if (it == map_pos_to_image.end()) { throw std::runtime_error("Chunk not found"); } - // SRV_INF("%s\n", "processing image..."); + SRV_INF("%s\n", "processing image..."); int32_t n_batch = llama_n_batch(ctx); int64_t t0 = ggml_time_ms(); llama_pos new_n_past = n_past; @@ -886,7 +1258,7 @@ public: n_batch, true, // logits last &new_n_past); - //SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0); + SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0); if (result != 0) { LOG_ERR("mtmd_helper_eval failed with status %d", result); n_pos_out = n_past; @@ -907,4 +1279,4 @@ static std::string fnv_hash(const uint8_t * data, size_t len) { hash *= fnv_prime; } return std::to_string(hash); -} \ No newline at end of file +}