mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-20 02:24:59 +00:00)

commit cd4c0b8aa6 (parent 7437d0c9ca)

    wip

1 changed file with 44 additions and 10 deletions
@@ -1579,15 +1579,32 @@ struct llama_server_context
             auto suffix_tokens = tokenize(slot.data, slot.params.input_suffix, false);

             const int space_token = 29871; // TODO: this should not be hardcoded
-            if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
+            if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0][0] == space_token) {
                 suffix_tokens.erase(suffix_tokens.begin());
             }

-            prefix_tokens.insert(prefix_tokens.begin(), llama_vocab_fim_pre(vocab));
-            prefix_tokens.insert(prefix_tokens.begin(), llama_vocab_bos(vocab)); // always add BOS
-            prefix_tokens.insert(prefix_tokens.end(), llama_vocab_fim_suf(vocab));
-            prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-            prefix_tokens.push_back(llama_vocab_fim_mid(vocab));
+            // Create llama_tokens vectors for the special tokens
+            llama_tokens fim_pre_tokens;
+            fim_pre_tokens.push_back(llama_vocab_fim_pre(vocab));
+            llama_tokens bos_tokens;
+            bos_tokens.push_back(llama_vocab_bos(vocab));
+            llama_tokens fim_suf_tokens;
+            fim_suf_tokens.push_back(llama_vocab_fim_suf(vocab));
+            llama_tokens fim_mid_tokens;
+            fim_mid_tokens.push_back(llama_vocab_fim_mid(vocab));
+
+            // Create server_tokens objects
+            server_tokens fim_pre_token(fim_pre_tokens, mctx != nullptr);
+            server_tokens bos_token(bos_tokens, mctx != nullptr);
+            server_tokens fim_suf_token(fim_suf_tokens, mctx != nullptr);
+            server_tokens fim_mid_token(fim_mid_tokens, mctx != nullptr);
+
+            // Insert tokens in the correct order
+            prefix_tokens.insert(prefix_tokens.begin(), fim_pre_token);
+            prefix_tokens.insert(prefix_tokens.begin(), bos_token); // always add BOS
+            prefix_tokens.insert(prefix_tokens.end(), fim_suf_token);
+            prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+            prefix_tokens.push_back(fim_mid_token);

             prompt_tokens = prefix_tokens;
         }
         else
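The hunk above changes the infill path so that each special token is first wrapped in a server_tokens object before being spliced into the prompt, while the resulting layout stays the same: BOS, FIM_PRE, the prefix, FIM_SUF, the suffix, then FIM_MID. A minimal, self-contained sketch of that insertion order, using placeholder token IDs in place of the real llama_vocab_* values:

    #include <cstdio>
    #include <vector>

    using llama_token  = int;
    using llama_tokens = std::vector<llama_token>;

    int main() {
        // Hypothetical special-token IDs; the real ones come from the vocab.
        const llama_token BOS = 1, FIM_PRE = 2, FIM_SUF = 3, FIM_MID = 4;

        llama_tokens prefix_tokens = {100, 101}; // tokenized input_prefix
        llama_tokens suffix_tokens = {200, 201}; // tokenized input_suffix

        // Same insertion order as the patch: FIM_PRE at the front, BOS in front
        // of that, FIM_SUF and the suffix at the back, FIM_MID last.
        prefix_tokens.insert(prefix_tokens.begin(), FIM_PRE);
        prefix_tokens.insert(prefix_tokens.begin(), BOS);
        prefix_tokens.insert(prefix_tokens.end(), FIM_SUF);
        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
        prefix_tokens.push_back(FIM_MID);

        for (llama_token t : prefix_tokens) std::printf("%d ", t);
        std::printf("\n"); // prints: 1 2 100 101 3 200 201 4
        return 0;
    }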
@@ -1620,7 +1637,12 @@ struct llama_server_context
                 {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
             });
             slot.truncated = true;
-            prompt_tokens = new_tokens;
+
+            // Convert new_tokens to server_tokens
+            std::vector<server_tokens> new_prompt_tokens;
+            server_tokens new_server_tokens(new_tokens, mctx != nullptr);
+            new_prompt_tokens.push_back(std::move(new_server_tokens));
+            prompt_tokens = std::move(new_prompt_tokens);

             slot.num_prompt_tokens = prompt_tokens.size();
             GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
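Here the truncated new_tokens vector is wrapped into a single server_tokens object before being stored as the slot's prompt; the second constructor argument presumably flags multimodal state via mctx. A rough sketch of what such a wrapper could look like, with server_tokens_sketch standing in for the real type and has_mtmd for mctx != nullptr:

    #include <cstddef>
    #include <utility>
    #include <vector>

    using llama_token  = int;
    using llama_tokens = std::vector<llama_token>;

    // Stand-in for the server_tokens type; the real one is assumed to carry
    // extra multimodal state when constructed with the flag set.
    struct server_tokens_sketch {
        llama_tokens tokens;
        bool has_mtmd = false;

        server_tokens_sketch(llama_tokens toks, bool mtmd)
            : tokens(std::move(toks)), has_mtmd(mtmd) {}

        std::size_t size() const { return tokens.size(); }
        llama_token operator[](std::size_t i) const { return tokens[i]; }
    };

    int main() {
        llama_tokens new_tokens = {11, 12, 13}; // truncated prompt tokens
        const bool has_mtmd = false;            // stands in for mctx != nullptr

        // Same shape as the patch: wrap the token vector once, store it in a vector.
        std::vector<server_tokens_sketch> prompt_tokens;
        prompt_tokens.emplace_back(std::move(new_tokens), has_mtmd);
        return 0;
    }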
@@ -1640,10 +1662,17 @@ struct llama_server_context
         // push the prompt into the sampling context (do not apply grammar)
         for (auto &token : prompt_tokens)
         {
-            common_sampler_accept(slot.ctx_sampling, token, false);
+            // Convert server_tokens to llama_token for sampling
+            llama_token tok = token[0]; // Get first token
+            common_sampler_accept(slot.ctx_sampling, tok, false);
         }

-        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+        // Convert server_tokens to llama_tokens for comparison
+        std::vector<llama_token> prompt_llama_tokens;
+        for (const auto &token : prompt_tokens) {
+            prompt_llama_tokens.push_back(token[0]);
+        }
+        slot.n_past = common_part(slot.cache_tokens, prompt_llama_tokens);

         // the last token of the cache is not in the KV cache until the next call to llama_decode
         // (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
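Since prompt_tokens now holds server_tokens entries rather than plain llama_token values, the hunk flattens them back (taking token[0] from each entry) before matching against the slot cache. A small sketch of that flatten-then-compare step, assuming common_part returns the length of the longest shared prefix of two token sequences; common_part_sketch and the sample token values are stand-ins, not the server's actual helpers:

    #include <cstddef>
    #include <vector>

    using llama_token = int;

    // Assumed behaviour of common_part: length of the longest shared prefix.
    static std::size_t common_part_sketch(const std::vector<llama_token> &a,
                                          const std::vector<llama_token> &b) {
        std::size_t n = 0;
        while (n < a.size() && n < b.size() && a[n] == b[n]) n++;
        return n;
    }

    int main() {
        // Each inner vector plays the role of one server_tokens entry; token[0]
        // is the llama_token the patch pulls out of it.
        std::vector<std::vector<llama_token>> prompt_tokens = {{5}, {6}, {7}};
        std::vector<llama_token> cache_tokens = {5, 6, 9};

        std::vector<llama_token> prompt_llama_tokens;
        for (const auto &token : prompt_tokens) {
            prompt_llama_tokens.push_back(token[0]);
        }

        std::size_t n_past = common_part_sketch(cache_tokens, prompt_llama_tokens); // == 2
        (void)n_past;
        return 0;
    }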
@@ -1681,7 +1710,12 @@ struct llama_server_context
             });
         }

-        slot.cache_tokens = prompt_tokens;
+        // Convert server_tokens to llama_tokens for cache
+        std::vector<llama_token> cache_llama_tokens;
+        for (const auto &token : prompt_tokens) {
+            cache_llama_tokens.push_back(token[0]);
+        }
+        slot.cache_tokens = cache_llama_tokens;

         if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
         {