Ettore Di Giacinto 2025-05-14 22:57:56 +02:00
parent 7437d0c9ca
commit cd4c0b8aa6

@@ -1579,15 +1579,32 @@ struct llama_server_context
         auto suffix_tokens = tokenize(slot.data, slot.params.input_suffix, false);
         const int space_token = 29871; // TODO: this should not be hardcoded
-        if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
+        if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0][0] == space_token) {
             suffix_tokens.erase(suffix_tokens.begin());
         }
-        prefix_tokens.insert(prefix_tokens.begin(), llama_vocab_fim_pre(vocab));
-        prefix_tokens.insert(prefix_tokens.begin(), llama_vocab_bos(vocab)); // always add BOS
-        prefix_tokens.insert(prefix_tokens.end(), llama_vocab_fim_suf(vocab));
-        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-        prefix_tokens.push_back(llama_vocab_fim_mid(vocab));
+        // Create llama_tokens vectors for the special tokens
+        llama_tokens fim_pre_tokens;
+        fim_pre_tokens.push_back(llama_vocab_fim_pre(vocab));
+        llama_tokens bos_tokens;
+        bos_tokens.push_back(llama_vocab_bos(vocab));
+        llama_tokens fim_suf_tokens;
+        fim_suf_tokens.push_back(llama_vocab_fim_suf(vocab));
+        llama_tokens fim_mid_tokens;
+        fim_mid_tokens.push_back(llama_vocab_fim_mid(vocab));
+        // Create server_tokens objects
+        server_tokens fim_pre_token(fim_pre_tokens, mctx != nullptr);
+        server_tokens bos_token(bos_tokens, mctx != nullptr);
+        server_tokens fim_suf_token(fim_suf_tokens, mctx != nullptr);
+        server_tokens fim_mid_token(fim_mid_tokens, mctx != nullptr);
+        // Insert tokens in the correct order
+        prefix_tokens.insert(prefix_tokens.begin(), fim_pre_token);
+        prefix_tokens.insert(prefix_tokens.begin(), bos_token); // always add BOS
+        prefix_tokens.insert(prefix_tokens.end(), fim_suf_token);
+        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+        prefix_tokens.push_back(fim_mid_token);
         prompt_tokens = prefix_tokens;
     }
     else
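The new branch builds one llama_tokens vector per special token and then wraps each in a server_tokens chunk before splicing it into prefix_tokens. As an illustration only (not part of the commit), that repetition could be captured in a small helper; the sketch assumes the server_tokens(llama_tokens, bool) constructor and the mctx != nullptr multimodal flag used in the hunk above, and the helper name is made up:

    // Hypothetical helper: wrap a single special token (BOS or a FIM marker)
    // into its own server_tokens chunk, mirroring the pattern in the diff.
    static server_tokens make_special_chunk(llama_token tok, bool has_mtmd) {
        llama_tokens single;
        single.push_back(tok);
        return server_tokens(single, has_mtmd);
    }

    // Possible usage, matching the insertions above:
    //   server_tokens bos_token = make_special_chunk(llama_vocab_bos(vocab), mctx != nullptr);
    //   prefix_tokens.insert(prefix_tokens.begin(), bos_token);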
@@ -1620,7 +1637,12 @@ struct llama_server_context
             {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
         });
         slot.truncated = true;
-        prompt_tokens = new_tokens;
+        // Convert new_tokens to server_tokens
+        std::vector<server_tokens> new_prompt_tokens;
+        server_tokens new_server_tokens(new_tokens, mctx != nullptr);
+        new_prompt_tokens.push_back(std::move(new_server_tokens));
+        prompt_tokens = std::move(new_prompt_tokens);
         slot.num_prompt_tokens = prompt_tokens.size();
         GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
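The truncation hunk performs the same conversion in the other direction: the truncated new_tokens (plain llama_token values) are packed into a single server_tokens chunk, and a one-element vector of chunks becomes the new prompt. A standalone sketch of that hand-off, assuming prompt_tokens is a std::vector<server_tokens> as the surrounding code suggests:

    #include <utility>
    #include <vector>

    // Sketch only: rebuild the prompt as a one-element vector of server_tokens.
    std::vector<server_tokens> rebuild_prompt(llama_tokens &new_tokens, bool has_mtmd) {
        std::vector<server_tokens> out;
        server_tokens chunk(new_tokens, has_mtmd); // one chunk holding every kept token
        out.push_back(std::move(chunk));           // moved in, as in the diff's push_back(std::move(...))
        return out;
    }

If prompt_tokens really is a vector of chunks, note that prompt_tokens.size() on the following line counts chunks rather than individual tokens.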
@@ -1640,10 +1662,17 @@ struct llama_server_context
         // push the prompt into the sampling context (do not apply grammar)
         for (auto &token : prompt_tokens)
         {
-            common_sampler_accept(slot.ctx_sampling, token, false);
+            // Convert server_tokens to llama_token for sampling
+            llama_token tok = token[0]; // Get first token
+            common_sampler_accept(slot.ctx_sampling, tok, false);
         }
-        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+        // Convert server_tokens to llama_tokens for comparison
+        std::vector<llama_token> prompt_llama_tokens;
+        for (const auto &token : prompt_tokens) {
+            prompt_llama_tokens.push_back(token[0]);
+        }
+        slot.n_past = common_part(slot.cache_tokens, prompt_llama_tokens);
         // the last token of the cache is not in the KV cache until the next call to llama_decode
         // (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
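As I read it, common_part here returns the length of the shared leading run of two token sequences (the portion already covered by the KV cache), which is why the new code first flattens prompt_tokens into plain llama_token values; note that the flattening keeps only token[0] of each server_tokens element. A minimal sketch of that prefix computation, which the real helper may implement slightly differently:

    #include <cstddef>
    #include <vector>

    // Count how many leading tokens two sequences have in common.
    static size_t common_prefix_len(const std::vector<llama_token> &a,
                                    const std::vector<llama_token> &b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }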
@@ -1681,7 +1710,12 @@ struct llama_server_context
             });
         }
-        slot.cache_tokens = prompt_tokens;
+        // Convert server_tokens to llama_tokens for cache
+        std::vector<llama_token> cache_llama_tokens;
+        for (const auto &token : prompt_tokens) {
+            cache_llama_tokens.push_back(token[0]);
+        }
+        slot.cache_tokens = cache_llama_tokens;
         if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
         {
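The cache hunk repeats the flattening loop used for the n_past computation, so the two call sites could in principle share one helper. A hypothetical sketch (not in the commit), assuming server_tokens exposes a const operator[] returning llama_token as the loops above rely on, and keeping the commit's first-token-per-chunk behaviour:

    #include <vector>

    // Hypothetical helper: flatten server_tokens chunks into plain llama_token
    // values, taking the first token of each (non-empty) chunk, exactly as the
    // loops in this diff do.
    static std::vector<llama_token> first_tokens(const std::vector<server_tokens> &chunks) {
        std::vector<llama_token> out;
        out.reserve(chunks.size());
        for (const auto &chunk : chunks) {
            out.push_back(chunk[0]);
        }
        return out;
    }

    // Possible usage at the two sites touched by this diff:
    //   slot.n_past       = common_part(slot.cache_tokens, first_tokens(prompt_tokens));
    //   slot.cache_tokens = first_tokens(prompt_tokens);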