diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index 7de1070c..c57f4070 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -1579,15 +1579,32 @@ struct llama_server_context
                 auto suffix_tokens = tokenize(slot.data, slot.params.input_suffix, false);
 
                 const int space_token = 29871; // TODO: this should not be hardcoded
-                if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
+                if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0][0] == space_token) {
                     suffix_tokens.erase(suffix_tokens.begin());
                 }
 
-                prefix_tokens.insert(prefix_tokens.begin(), llama_vocab_fim_pre(vocab));
-                prefix_tokens.insert(prefix_tokens.begin(), llama_vocab_bos(vocab)); // always add BOS
-                prefix_tokens.insert(prefix_tokens.end(), llama_vocab_fim_suf(vocab));
-                prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-                prefix_tokens.push_back(llama_vocab_fim_mid(vocab));
+                // Create llama_tokens vectors for the special tokens
+                llama_tokens fim_pre_tokens;
+                fim_pre_tokens.push_back(llama_vocab_fim_pre(vocab));
+                llama_tokens bos_tokens;
+                bos_tokens.push_back(llama_vocab_bos(vocab));
+                llama_tokens fim_suf_tokens;
+                fim_suf_tokens.push_back(llama_vocab_fim_suf(vocab));
+                llama_tokens fim_mid_tokens;
+                fim_mid_tokens.push_back(llama_vocab_fim_mid(vocab));
+
+                // Create server_tokens objects
+                server_tokens fim_pre_token(fim_pre_tokens, mctx != nullptr);
+                server_tokens bos_token(bos_tokens, mctx != nullptr);
+                server_tokens fim_suf_token(fim_suf_tokens, mctx != nullptr);
+                server_tokens fim_mid_token(fim_mid_tokens, mctx != nullptr);
+
+                // Insert tokens in the correct order
+                prefix_tokens.insert(prefix_tokens.begin(), fim_pre_token);
+                prefix_tokens.insert(prefix_tokens.begin(), bos_token); // always add BOS
+                prefix_tokens.insert(prefix_tokens.end(), fim_suf_token);
+                prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+                prefix_tokens.push_back(fim_mid_token);
                 prompt_tokens = prefix_tokens;
             }
             else
@@ -1620,7 +1637,12 @@ struct llama_server_context
                     {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
                 });
                 slot.truncated = true;
-                prompt_tokens = new_tokens;
+
+                // Convert new_tokens to server_tokens
+                std::vector<server_tokens> new_prompt_tokens;
+                server_tokens new_server_tokens(new_tokens, mctx != nullptr);
+                new_prompt_tokens.push_back(std::move(new_server_tokens));
+                prompt_tokens = std::move(new_prompt_tokens);
 
                 slot.num_prompt_tokens = prompt_tokens.size();
                 GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
@@ -1640,10 +1662,17 @@ struct llama_server_context
                 // push the prompt into the sampling context (do not apply grammar)
                 for (auto &token : prompt_tokens)
                 {
-                    common_sampler_accept(slot.ctx_sampling, token, false);
+                    // Convert server_tokens to llama_token for sampling
+                    llama_token tok = token[0]; // Get first token
+                    common_sampler_accept(slot.ctx_sampling, tok, false);
                 }
 
-                slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                // Convert server_tokens to llama_tokens for comparison
+                std::vector<llama_token> prompt_llama_tokens;
+                for (const auto &token : prompt_tokens) {
+                    prompt_llama_tokens.push_back(token[0]);
+                }
+                slot.n_past = common_part(slot.cache_tokens, prompt_llama_tokens);
 
                 // the last token of the cache is not in the KV cache until the next call to llama_decode
                 // (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
@@ -1681,7 +1710,12 @@ struct llama_server_context
                     });
                 }
 
-                slot.cache_tokens = prompt_tokens;
+                // Convert server_tokens to llama_tokens for cache
+                std::vector<llama_token> cache_llama_tokens;
+                for (const auto &token : prompt_tokens) {
+                    cache_llama_tokens.push_back(token[0]);
+                }
+                slot.cache_tokens = cache_llama_tokens;
 
                 if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
                 {
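
The conversion from `server_tokens` entries back to plain `llama_token` values is repeated in three places in this patch (sampler warm-up, the `common_part` comparison, and the cache update). For illustration only, below is a minimal standalone sketch of that pattern as a helper; `to_llama_tokens` is a hypothetical name, and `server_tokens`/`llama_tokens` are simplified stand-ins rather than the real llama.cpp server types.

// Sketch only: simplified stand-ins for the llama.cpp types referenced by the patch.
#include <cstdint>
#include <cstddef>
#include <utility>
#include <vector>

using llama_token  = std::int32_t;
using llama_tokens = std::vector<llama_token>;

// Simplified stand-in: the real server_tokens also tracks multimodal chunks.
struct server_tokens {
    llama_tokens toks;
    server_tokens(llama_tokens t, bool /*has_mtmd*/) : toks(std::move(t)) {}
    llama_token operator[](std::size_t i) const { return toks[i]; }
};

// Hypothetical helper mirroring the loops in the patch: flatten a sequence of
// single-token server_tokens entries into one llama_tokens vector.
static llama_tokens to_llama_tokens(const std::vector<server_tokens> & prompt_tokens) {
    llama_tokens out;
    out.reserve(prompt_tokens.size());
    for (const auto & token : prompt_tokens) {
        out.push_back(token[0]); // each entry holds exactly one token here
    }
    return out;
}

int main() {
    std::vector<server_tokens> prompt;
    prompt.emplace_back(llama_tokens{1}, false);        // e.g. BOS
    prompt.emplace_back(llama_tokens{29871}, false);    // e.g. the leading-space token
    const llama_tokens flat = to_llama_tokens(prompt);  // {1, 29871}
    return flat.size() == 2 ? 0 : 1;
}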