This commit is contained in:
Ettore Di Giacinto 2025-05-14 20:11:06 +02:00
parent 029f97c2a2
commit 7437d0c9ca
6 changed files with 552 additions and 385 deletions

View file

@ -6,7 +6,7 @@ BINARY_NAME=local-ai
DETECT_LIBS?=true
# llama.cpp versions
CPPLLAMA_VERSION?=de4c07f93783a1a96456a44dc16b9db538ee1618
CPPLLAMA_VERSION?=e5c834f718a32b7584f142799bbf508fddb9021c
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp

View file

@ -1,17 +1,17 @@
## XXX: In some versions of CMake clip wasn't being built before llama.
## This is an hack for now, but it should be fixed in the future.
set(TARGET myclip)
add_library(${TARGET} clip.cpp clip.h clip-impl.h llava.cpp llava.h)
install(TARGETS ${TARGET} LIBRARY)
target_include_directories(myclip PUBLIC .)
target_include_directories(myclip PUBLIC ../..)
target_include_directories(myclip PUBLIC ../../common)
target_link_libraries(${TARGET} PRIVATE common ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if (NOT MSVC)
target_compile_options(${TARGET} PRIVATE -Wno-cast-qual) # stb_image.h
endif()
# set(TARGET myclip)
# add_library(${TARGET} clip.cpp clip.h clip-impl.h llava.cpp llava.h)
# install(TARGETS ${TARGET} LIBRARY)
# target_include_directories(myclip PUBLIC .)
# target_include_directories(myclip PUBLIC ../..)
# target_include_directories(myclip PUBLIC ../../common)
# target_link_libraries(${TARGET} PRIVATE common ggml llama ${CMAKE_THREAD_LIBS_INIT})
# target_compile_features(${TARGET} PRIVATE cxx_std_11)
# if (NOT MSVC)
# target_compile_options(${TARGET} PRIVATE -Wno-cast-qual) # stb_image.h
# endif()
# END CLIP hack
@ -75,7 +75,11 @@ add_library(hw_grpc_proto
${hw_proto_hdrs} )
add_executable(${TARGET} grpc-server.cpp utils.hpp json.hpp)
target_link_libraries(${TARGET} PRIVATE common llama myclip ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
target_include_directories(${TARGET} PRIVATE ../llava)
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE common llama mtmd ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
absl::flags_parse
gRPC::${_REFLECTION}
gRPC::${_GRPC_GRPCPP}

View file

@ -11,8 +11,7 @@
#include <memory>
#include <string>
#include <getopt.h>
#include "clip.h"
#include "llava.h"
#include "mtmd.h"
#include "log.h"
#include "stb_image.h"
#include "common.h"
@ -210,6 +209,8 @@ struct llama_client_slot
int32_t num_prompt_tokens_processed = 0;
json prompt;
json data;
std::string generated_text;
llama_token sampled;
std::vector<llama_token> cache_tokens;
@ -239,7 +240,7 @@ struct llama_client_slot
int32_t n_past_se = 0; // self-extend
// multimodal
std::vector<slot_image> images;
mtmd_context * mctx = nullptr;
// stats
size_t sent_count = 0;
@ -270,17 +271,6 @@ struct llama_client_slot
n_past_se = 0;
generated_token_probs.clear();
for (slot_image & img : images)
{
free(img.image_embedding);
if (img.img_data) {
clip_image_u8_free(img.img_data);
}
img.prefix_prompt = "";
}
images.clear();
}
bool has_budget(common_params &global_params) {
@ -456,6 +446,9 @@ struct llama_server_context
llama_context *ctx = nullptr;
const llama_vocab * vocab = nullptr;
// multimodal
mtmd_context * mctx = nullptr;
clip_ctx *clp_ctx = nullptr;
common_params params;
@ -494,6 +487,10 @@ struct llama_server_context
~llama_server_context()
{
if (mctx) {
mtmd_free(mctx);
mctx = nullptr;
}
if (ctx)
{
llama_free(ctx);
@ -512,12 +509,14 @@ struct llama_server_context
if (!params.mmproj.path.empty()) {
multimodal = true;
LOG_INFO("Multi Modal Mode Enabled", {});
clp_ctx = clip_init(params.mmproj.path.c_str(), clip_context_params {
/* use_gpu */ has_gpu,
/*verbosity=*/ GGML_LOG_LEVEL_INFO,
});
if(clp_ctx == nullptr) {
LOG_ERR("unable to load clip model: %s", params.mmproj.path.c_str());
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = has_gpu;
mparams.print_timings = false;
mparams.n_threads = params.cpuparams.n_threads;
mparams.verbosity = GGML_LOG_LEVEL_INFO;
mctx = mtmd_init_from_file(params.mmproj.path.c_str(), model, mparams);
if (mctx == nullptr) {
LOG_ERR("failed to load multimodal model, '%s'\n", params.mmproj.path.c_str());
return false;
}
@ -579,6 +578,8 @@ struct llama_server_context
slot.id = i;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
slot.mctx = mctx;
//slot.cache_tokens.has_mtmd = mctx != nullptr;
LOG_INFO("new slot", {
{"slot_id", slot.id},
@ -616,54 +617,61 @@ struct llama_server_context
batch = llama_batch_init(n_ctx, 0, params.n_parallel);
}
std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
std::vector<server_tokens> tokenize(json &data, const json & json_prompt, bool add_bos) const
{
// TODO: currently, we tokenize using special tokens by default
// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
// but it's better compared to completely ignoring ChatML and other chat templates
const bool TMP_FORCE_SPECIAL = true;
mtmd::bitmaps bitmaps;
std::vector<server_tokens> inputs;
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
std::vector<llama_token> prompt_tokens;
if (mctx != nullptr)
{
const auto &images_data = data.find("image_data");
if (images_data != data.end() && images_data->is_array())
{
for (const auto &img : *images_data)
{
const std::vector<uint8_t> image_buffer = base64_decode(img["data"].get<std::string>());
if (json_prompt.is_array())
{
bool first = true;
for (const auto& p : json_prompt)
{
if (p.is_string())
{
auto s = p.template get<std::string>();
std::vector<llama_token> p;
if (first)
{
p = common_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
first = false;
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(image_buffer.data(), image_buffer.size()));
if (!bmp.ptr) {
throw std::runtime_error("Failed to load image");
}
else
{
p = common_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
// calculate bitmap hash (for KV caching)
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
bmp.set_id(hash.c_str());
bitmaps.entries.push_back(std::move(bmp));
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
}
else
{
if (first)
{
first = false;
}
prompt_tokens.push_back(p.template get<llama_token>());
}
}
}
else
{
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
}
return prompt_tokens;
// multimodal
std::string prompt_str = json_prompt.template get<std::string>();
mtmd_input_text inp_txt = {
prompt_str.c_str(),
/* add_special */ true,
/* parse_special */ true,
};
mtmd::input_chunks chunks(mtmd_input_chunks_init());
auto bitmaps_c_ptr = bitmaps.c_ptr();
int32_t tokenized = mtmd_tokenize(mctx,
chunks.ptr.get(),
&inp_txt,
bitmaps_c_ptr.data(),
bitmaps_c_ptr.size());
if (tokenized != 0) {
throw std::runtime_error("Failed to tokenize prompt");
}
server_tokens tmp(chunks, true);
inputs.push_back(std::move(tmp));
} else {
// non-multimodal version
auto tokenized_prompts = tokenize_input_prompts(vocab, json_prompt, true, true);
for (auto & p : tokenized_prompts) {
auto tmp = server_tokens(p, mctx != nullptr);
inputs.push_back(std::move(tmp));
}
}
return inputs;
}
llama_client_slot* get_slot(int id) {
@ -716,6 +724,8 @@ struct llama_server_context
slot->sparams.grammar_triggers = grammar_triggers;
slot->sparams.grammar_lazy = grammar_lazy;
slot->data = data;
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
// Might be better to reject the request with a 400 ?
LOG_WARNING("Max tokens to predict exceeds server configuration", {
@ -757,43 +767,7 @@ struct llama_server_context
if (json_value(data, "ignore_eos", false) && has_eos_token) {
slot->sparams.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
}
/*
slot->sparams.penalty_prompt_tokens.clear();
slot->sparams.use_penalty_prompt_tokens = false;
const auto &penalty_prompt = data.find("penalty_prompt");
if (penalty_prompt != data.end())
{
if (penalty_prompt->is_string())
{
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false);
slot->sparams.penalty_prompt_tokens.swap(penalty_tokens);
if (slot->params.n_predict > 0)
{
slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict);
}
slot->sparams.use_penalty_prompt_tokens = true;
}
else if (penalty_prompt->is_array())
{
const auto n_tokens = penalty_prompt->size();
slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict));
const int n_vocab = llama_n_vocab(model);
for (const auto &penalty_token : *penalty_prompt)
{
if (penalty_token.is_number_integer())
{
const auto tok = penalty_token.get<llama_token>();
if (tok >= 0 && tok < n_vocab)
{
slot->sparams.penalty_prompt_tokens.push_back(tok);
}
}
}
slot->sparams.use_penalty_prompt_tokens = true;
}
}
*/
slot->sparams.logit_bias.clear();
const auto &logit_bias = data.find("logit_bias");
@ -869,79 +843,6 @@ struct llama_server_context
}
if (multimodal)
{
const auto &images_data = data.find("image_data");
if (images_data != data.end() && images_data->is_array())
{
for (const auto &img : *images_data)
{
const std::vector<uint8_t> image_buffer = base64_decode(img["data"].get<std::string>());
slot_image img_sl;
img_sl.id = img.count("id") != 0 ? img["id"].get<int>() : slot->images.size();
img_sl.img_data = clip_image_u8_init();
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
{
LOG_ERR("%s: failed to load image, slot_id: %d, img_sl_id: %d",
__func__,
slot->id,
img_sl.id
);
return false;
}
LOG_VERBOSE("image loaded", {
{"slot_id", slot->id},
{"img_sl_id", img_sl.id}
});
img_sl.request_encode_image = true;
slot->images.push_back(img_sl);
}
// process prompt
// example: system prompt [img-102] user [img-103] describe [img-134] -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]}
if (slot->images.size() > 0 && !slot->prompt.is_array())
{
std::string prompt = slot->prompt.get<std::string>();
size_t pos = 0, begin_prefix = 0;
std::string pattern = "[img-";
while ((pos = prompt.find(pattern, pos)) != std::string::npos) {
size_t end_prefix = pos;
pos += pattern.length();
size_t end_pos = prompt.find(']', pos);
if (end_pos != std::string::npos)
{
std::string image_id = prompt.substr(pos, end_pos - pos);
try
{
int img_id = std::stoi(image_id);
bool found = false;
for (slot_image &img : slot->images)
{
if (img.id == img_id) {
found = true;
img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix);
begin_prefix = end_pos + 1;
break;
}
}
if (!found) {
LOG("ERROR: Image with id: %i, not found.\n", img_id);
slot->images.clear();
return false;
}
} catch (const std::invalid_argument& e) {
LOG("Invalid image number id in prompt\n");
slot->images.clear();
return false;
}
}
}
slot->prompt = "";
slot->params.input_suffix = prompt.substr(begin_prefix);
slot->params.cache_prompt = false; // multimodal doesn't support cache prompt
}
}
}
if (slot->ctx_sampling != nullptr)
{
@ -1189,26 +1090,6 @@ struct llama_server_context
return slot.has_next_token; // continue
}
bool process_images(llama_client_slot &slot) const
{
for (slot_image &img : slot.images)
{
if (!img.request_encode_image)
{
continue;
}
if (!llava_image_embed_make_with_clip_img(clp_ctx, params.cpuparams.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) {
LOG("Error processing the given image");
return false;
}
img.request_encode_image = false;
}
return slot.images.size() > 0;
}
void send_error(task_server& task, const std::string &error)
{
LOG("task %i - error: %s\n", task.id, error.c_str());
@ -1451,74 +1332,6 @@ struct llama_server_context
}
}
// for multiple images processing
bool ingest_images(llama_client_slot &slot, int n_batch)
{
int image_idx = 0;
while (image_idx < (int) slot.images.size())
{
slot_image &img = slot.images[image_idx];
// process prefix prompt
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
{
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
if (llama_decode(ctx, batch_view))
{
LOG("%s : failed to eval\n", __func__);
return false;
}
}
// process image with llm
for (int i = 0; i < img.image_tokens; i += n_batch)
{
int n_eval = img.image_tokens - i;
if (n_eval > n_batch)
{
n_eval = n_batch;
}
const int n_embd = llama_model_n_embd(model);
float * embd = img.image_embedding + i * n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, slot.n_past, 0);
if (llama_decode(ctx, llava_batch.batch))
{
LOG("%s : failed to eval image\n", __func__);
return false;
}
slot.n_past += n_eval;
}
image_idx++;
common_batch_clear(batch);
// append prefix of next image
const auto json_prompt = (image_idx >= (int) slot.images.size()) ?
slot.params.input_suffix : // no more images, then process suffix prompt
(json)(slot.images[image_idx].prefix_prompt);
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
for (int i = 0; i < (int) append_tokens.size(); ++i)
{
common_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
slot.n_past += 1;
}
}
return true;
}
void request_cancel(int task_id)
{
task_server task;
@ -1733,7 +1546,7 @@ struct llama_server_context
{
for (auto & slot : slots)
{
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
// empty prompt passed -> release the slot and send empty response
// note: infill mode allows empty prompt
@ -1750,7 +1563,7 @@ struct llama_server_context
{
slot.state = PROCESSING;
slot.command = NONE;
std::vector<llama_token> prompt_tokens;
std::vector<server_tokens> prompt_tokens;
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_genereration = 0;
@ -1762,8 +1575,8 @@ struct llama_server_context
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
auto prefix_tokens = tokenize(slot.data, slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.data, slot.params.input_suffix, false);
const int space_token = 29871; // TODO: this should not be hardcoded
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
@ -1779,7 +1592,7 @@ struct llama_server_context
}
else
{
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
prompt_tokens = tokenize(slot.data, slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
}
slot.num_prompt_tokens = prompt_tokens.size();
@ -1892,18 +1705,36 @@ struct llama_server_context
});
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
// process the prefix of first image
std::vector<server_tokens> prefix_tokens = prompt_tokens;
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
// check if we should process the image
if (slot.n_past < slot.n_prompt_tokens
&& slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
// process the image
int32_t new_n_past;
int32_t res = prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
int32_t n_pos = new_n_past - slot.n_past;
if (res != 0) {
slot.release();
LOG_ERR("failed to process image, res = %d\n", res);
continue;
}
slot.n_past += n_pos;
// slot.n_prompt_tokens_processed += n_pos;
}
LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
});
const bool has_images = process_images(slot);
// process the prefix of first image
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
int32_t ga_i = slot.ga_i;
int32_t ga_n = slot.ga_n;
@ -1923,19 +1754,6 @@ struct llama_server_context
slot_npast++;
}
if (has_images && !ingest_images(slot, n_batch))
{
LOG_ERR("%s: failed processing images Slot id : %d, Task id: %d",
__func__,
slot.id,
slot.task_id
);
// FIXME @phymbert: to be properly tested
// early returning without changing the slot state will block the slot for ever
// no one at the moment is checking the return value
return false;
}
// extract the logits only for the last token
if (batch.n_tokens > 0)
{
@ -2164,26 +1982,6 @@ static void start_llama_server() {
json parse_options(bool streaming, const backend::PredictOptions* predict, llama_server_context &llama)
{
// This is for example a slot data from the json data
// slot->params.stream = json_value(data, "stream", false);
// slot->params.cache_prompt = json_value(data, "cache_prompt", false);
// slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
// slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
// slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
// slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
// slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
// slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
// slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
// slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
// slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
// slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
// slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
// slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
// slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
// slot->params.seed = json_value(data, "seed", default_params.seed);
// slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
// slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
// Create now a json data from the prediction options instead
//
json data;
@ -2228,69 +2026,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
return data;
}
// static void parse_options_completion(bool streaming,const backend::PredictOptions* predict, llama_server_context &llama)
// {
// // https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L673
// gpt_params default_params;
// llama.stream = streaming;
// llama.params.n_predict = predict->tokens() == 0 ? -1 : predict->tokens();
// llama.params.sparams.top_k = predict->topk();
// llama.params.sparams.top_p = predict->topp();
// llama.params.sparams.typical_p = predict->typicalp();
// llama.params.sparams.penalty_last_n = predict->repeat();
// llama.params.sparams.temp = predict->temperature();
// llama.params.sparams.penalty_repeat = predict->penalty();
// llama.params.sparams.penalty_present = predict->presencepenalty();
// llama.params.sparams.penalty_freq = predict->frequencypenalty();
// llama.params.sparams.mirostat = predict->mirostat();
// llama.params.sparams.mirostat_tau = predict->mirostattau();
// llama.params.sparams.mirostat_eta = predict->mirostateta();
// llama.params.n_keep = predict->nkeep();
// llama.params.seed = predict->seed();
// llama.params.sparams.grammar = predict->grammar();
// // llama.params.n_probs = predict->
// llama.params.prompt = predict->prompt();
// llama.params.sparams.logit_bias.clear();
// if (predict->ignoreeos())
// {
// llama.params.sparams.logit_bias[llama_token_eos(llama.model)] = -INFINITY;
// }
// // const auto &logit_bias = body.find("logit_bias");
// // if (logit_bias != body.end() && logit_bias->is_array())
// // {
// // const int n_vocab = llama_n_vocab(llama.model);
// // for (const auto &el : *logit_bias)
// // {
// // if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
// // {
// // llama_token tok = el[0].get<llama_token>();
// // if (tok >= 0 && tok < n_vocab)
// // {
// // if (el[1].is_number())
// // {
// // llama.params.logit_bias[tok] = el[1].get<float>();
// // }
// // else if (el[1].is_boolean() && !el[1].get<bool>())
// // {
// // llama.params.logit_bias[tok] = -INFINITY;
// // }
// // }
// // }
// // }
// // }
// llama.params.antiprompt.clear();
// for (const std::string& stopPrompt : predict->stopprompts()) {
// if (!stopPrompt.empty())
// {
// llama.params.antiprompt.push_back(stopPrompt);
// }
// }
// }
const std::vector<ggml_type> kv_cache_types = {
GGML_TYPE_F32,
@ -2603,10 +2338,10 @@ public:
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response){
json data = parse_options(false, request, llama);
std::vector<llama_token> tokens = llama.tokenize(data["prompt"],false);
std::vector<server_tokens> tokens = llama.tokenize(data, data["prompt"],false);
for (int i=0 ; i< tokens.size(); i++){
response->add_tokens(tokens[i]);
response->add_tokens(tokens[i].llama_token);
}
return grpc::Status::OK;

View file

@ -20,9 +20,9 @@ fi
## XXX: In some versions of CMake clip wasn't being built before llama.
## This is an hack for now, but it should be fixed in the future.
cp -rfv llama.cpp/tools/mtmd/clip.h llama.cpp/tools/grpc-server/clip.h
cp -rfv llama.cpp/tools/mtmd/clip-impl.h llama.cpp/tools/grpc-server/clip-impl.h
cp -rfv llama.cpp/tools/mtmd/llava.cpp llama.cpp/tools/grpc-server/llava.cpp
echo '#include "llama.h"' > llama.cpp/tools/grpc-server/llava.h
cat llama.cpp/tools/mtmd/llava.h >> llama.cpp/tools/grpc-server/llava.h
cp -rfv llama.cpp/tools/mtmd/clip.cpp llama.cpp/tools/grpc-server/clip.cpp
# cp -rfv llama.cpp/tools/mtmd/clip.h llama.cpp/tools/grpc-server/clip.h
# cp -rfv llama.cpp/tools/mtmd/clip-impl.h llama.cpp/tools/grpc-server/clip-impl.h
# cp -rfv llama.cpp/tools/mtmd/llava.cpp llama.cpp/tools/grpc-server/llava.cpp
# echo '#include "llama.h"' > llama.cpp/tools/grpc-server/llava.h
# cat llama.cpp/tools/mtmd/llava.h >> llama.cpp/tools/grpc-server/llava.h
# cp -rfv llama.cpp/tools/mtmd/clip.cpp llama.cpp/tools/grpc-server/clip.cpp

View file

@ -480,4 +480,431 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
}
return ret;
}
//
// tokenizer and input processing utils
//
static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
for (const auto & e : data) {
if (!e.is_number_integer()) {
return false;
}
}
return true;
}
return false;
}
// is array having BOTH numbers & strings?
static bool json_is_array_of_mixed_numbers_strings(const json & data) {
bool seen_string = false;
bool seen_number = false;
if (data.is_array()) {
for (const auto & e : data) {
seen_string |= e.is_string();
seen_number |= e.is_number_integer();
if (seen_number && seen_string) {
return true;
}
}
}
return false;
}
// get value by path(key1 / key2)
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
json result = json::object();
for (const std::string & path : paths) {
json current = js;
const auto keys = string_split<std::string>(path, /*separator*/ '/');
bool valid_path = true;
for (const std::string & k : keys) {
if (valid_path && current.is_object() && current.contains(k)) {
current = current[k];
} else {
valid_path = false;
}
}
if (valid_path) {
result[path] = current;
}
}
return result;
}
/**
* this handles 2 cases:
* - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/
static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
llama_tokens prompt_tokens;
if (json_prompt.is_array()) {
bool first = true;
for (const auto & p : json_prompt) {
if (p.is_string()) {
auto s = p.template get<std::string>();
llama_tokens p;
if (first) {
p = common_tokenize(vocab, s, add_special, parse_special);
first = false;
} else {
p = common_tokenize(vocab, s, false, parse_special);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
} else {
if (first) {
first = false;
}
prompt_tokens.push_back(p.template get<llama_token>());
}
}
} else {
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
}
return prompt_tokens;
}
/**
* break the input "prompt" object into multiple prompt if needed, then tokenize them
* this supports these cases:
* - "prompt": "string"
* - "prompt": [12, 34, 56]
* - "prompt": [12, 34, "string", 56, 78]
* and multiple prompts (multi-tasks):
* - "prompt": ["string1", "string2"]
* - "prompt": ["string1", [12, 34, 56]]
* - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/
static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
std::vector<llama_tokens> result;
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
// string or mixed
result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special));
} else if (json_is_array_of_numbers(json_prompt)) {
// array of tokens
result.push_back(json_prompt.get<llama_tokens>());
} else if (json_prompt.is_array()) {
// array of prompts
result.reserve(json_prompt.size());
for (const auto & p : json_prompt) {
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
} else if (json_is_array_of_numbers(p)) {
// array of tokens
result.push_back(p.get<llama_tokens>());
} else {
throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
}
}
} else {
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
}
if (result.empty()) {
throw std::runtime_error("\"prompt\" must not be empty");
}
return result;
}
//
// utils for interacting with libmtmd
// (may need to refactor in near future)
//
/**
* server_tokens is a helper to manage the input tokens and image for the server.
* it is made this way to simplify the logic of KV cache management.
*/
struct server_tokens {
bool has_mtmd = false;
private: // disallow accessing these members directly, risking out-of-sync
// map a **start** position in tokens to the image chunk
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_image;
// list of tokens
// it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
// a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
// important: for models using mrope, an image can contain multiple tokens but will use only one **position**
llama_tokens tokens;
// for ex. with input of 5 text tokens and 2 images:
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
// pos 0 1 2 3 4 5 6 7 8 9
// map_pos_to_image will contain: {5, img0}, {8, img1}
public:
server_tokens() = default;
~server_tokens() = default;
// Prevent copying
server_tokens(const server_tokens&) = delete;
server_tokens& operator=(const server_tokens&) = delete;
// Allow moving (usually implicitly generated if members are movable)
server_tokens(server_tokens&&) = default;
server_tokens& operator=(server_tokens&&) = default;
// Allow accessing elements using [] operator
llama_token operator[](size_t index) { return tokens[index]; }
const llama_token& operator[](size_t index) const { return tokens[index]; }
server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) {
for (size_t i = 0; i < mtmd_chunks.size(); ++i) {
push_back(mtmd_chunks[i]);
}
}
server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
// for debugging
std::string str() const {
std::ostringstream oss;
oss << "tokens: ";
for (const auto & t : tokens) {
if (t == LLAMA_TOKEN_NULL) {
oss << "<embd> ";
} else {
oss << t << " ";
}
}
oss << "\n";
oss << "image pos: ";
for (const auto & it : map_pos_to_image) {
oss << it.first << ", ";
}
return oss.str();
}
const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
auto it = map_pos_to_image.find(pos);
if (it != map_pos_to_image.end()) {
return it->second;
} else {
throw std::runtime_error("Chunk not found");
}
}
void push_back(llama_token tok) {
if (tok == LLAMA_TOKEN_NULL) {
throw std::runtime_error("Invalid token");
}
tokens.emplace_back(tok);
}
// will create a copy of the chunk if it contains non-text data
void push_back(const mtmd_input_chunk * chunk) {
auto type = mtmd_input_chunk_get_type(chunk);
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
GGML_ASSERT(has_mtmd);
auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk);
const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
llama_pos start_pos = tokens.size();
for (int i = 0; i < n_pos; ++i) {
tokens.emplace_back(LLAMA_TOKEN_NULL);
}
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
map_pos_to_image[start_pos] = std::move(new_chunk);
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens;
auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
for (size_t i = 0; i < n_tokens; ++i) {
push_back(text_tokens[i]);
}
} else {
GGML_ABORT("Invalid chunk type");
}
}
// for compatibility with context shift and prompt truncation
void insert(const llama_tokens & inp_tokens) {
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
}
// for compatibility with speculative decoding, ctx shift, slot save/load
const llama_tokens & get_text_tokens() const {
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
return tokens;
}
// for compatibility with speculative decoding
void set_token(llama_pos pos, llama_token id) {
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
tokens[pos] = id;
}
size_t size() const {
return tokens.size();
}
bool empty() const {
return tokens.empty();
}
void clear() {
tokens.clear();
}
void resize(size_t n) {
GGML_ASSERT(n <= tokens.size());
if (has_mtmd) {
// we throw an error if we try to remove a token in the middle of an image
// for ex. with input of 5 text tokens and 2 images:
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
// n 1 2 3 4 5 6 7 8 9 10
// allowed to resize ^ ^
// disallowed to resize ^ ^ ^
if (n > 0) {
llama_token last_token = tokens[n - 1];
// make sure we never remove tokens in the middle of an image
if (last_token == LLAMA_TOKEN_NULL) {
find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
}
}
// remove all image chunks that are not used anymore
for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) {
llama_pos pos = it->first;
if (pos >= (llama_pos)n) {
it = map_pos_to_image.erase(it);
} else {
++it;
}
}
}
tokens.resize(n);
}
std::string detokenize(const llama_context * ctx, bool special) const {
llama_tokens text_tokens;
text_tokens.reserve(tokens.size());
for (const auto & t : tokens) {
if (t != LLAMA_TOKEN_NULL) {
text_tokens.push_back(t);
}
}
return common_detokenize(ctx, text_tokens, special);
}
size_t get_common_prefix(const server_tokens & b) const {
size_t max_idx = std::min(tokens.size(), b.tokens.size());
for (size_t i = 0; i < max_idx; ++i) {
auto & ai = tokens[i];
auto & bi = b.tokens[i];
if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
GGML_ASSERT(has_mtmd);
const auto & a_chunk = find_chunk(i);
const auto & b_chunk = b.find_chunk(i);
GGML_ASSERT(a_chunk && b_chunk);
const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get());
const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get());
std::string ai_id = mtmd_image_tokens_get_id(a_img);
std::string bi_id = mtmd_image_tokens_get_id(b_img);
size_t a_pos = mtmd_image_tokens_get_n_pos(a_img);
size_t b_pos = mtmd_image_tokens_get_n_pos(b_img);
if (ai_id == bi_id && a_pos == b_pos) {
GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen
i += a_pos - 1; // will be +1 by the for loop
continue;
} else {
return i;
}
} else if (ai == bi) {
continue;
} else {
return i;
}
}
return max_idx; // all tokens are equal
}
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
for (size_t i = 0; i < tokens.size(); ++i) {
auto & t = tokens[i];
if (t == LLAMA_TOKEN_NULL) {
try {
const auto & chunk = find_chunk(i);
const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get());
size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
i += n_pos - 1; // will be +1 by the for loop
} catch (const std::exception & e) {
return false;
}
} else if (t < 0 || t >= n_vocab) {
return false;
}
}
return true;
}
// encode and decode the image chunk
int32_t process_chunk(
llama_context * ctx,
mtmd_context * mctx,
llama_pos n_past,
int32_t seq_id,
llama_pos & n_pos_out) {
auto it = map_pos_to_image.find(n_past);
if (it == map_pos_to_image.end()) {
throw std::runtime_error("Chunk not found");
}
// SRV_INF("%s\n", "processing image...");
int32_t n_batch = llama_n_batch(ctx);
int64_t t0 = ggml_time_ms();
llama_pos new_n_past = n_past;
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
it->second.get(), // chunk
n_past,
seq_id,
n_batch,
true, // logits last
&new_n_past);
//SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
if (result != 0) {
LOG_ERR("mtmd_helper_eval failed with status %d", result);
n_pos_out = n_past;
return result;
}
n_pos_out = new_n_past;
return 0;
}
};
// Computes FNV-1a hash of the data
static std::string fnv_hash(const uint8_t * data, size_t len) {
const uint64_t fnv_prime = 0x100000001b3ULL;
uint64_t hash = 0xcbf29ce484222325ULL;
for (size_t i = 0; i < len; ++i) {
hash ^= data[i];
hash *= fnv_prime;
}
return std::to_string(hash);
}

View file

@ -21,7 +21,8 @@ type MultimodalContent struct {
ID int
}
const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}[img-{{.ID}}]{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
// https://github.com/ggml-org/llama.cpp/blob/be1d4a13db26750fac702ceb3af88ae4f39dc9f4/tools/mtmd/mtmd.h#L42
const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}<__image__>{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) {
if templateString == "" {