feat: use tokenizer.apply_chat_template() in vLLM (#1990)

Use tokenizer.apply_chat_template() in vLLM

Signed-off-by: Ludovic LEROUX <ludovic@inpher.io>
Ludovic Leroux 2024-04-11 13:20:22 -04:00 committed by GitHub
parent cbda06fb96
commit 12c0d9443e
34 changed files with 3088 additions and 989 deletions


@@ -14,6 +14,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from vllm.transformers_utils.tokenizer import get_tokenizer

 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -71,7 +72,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         """
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

-    def LoadModel(self, request, context):
+    async def LoadModel(self, request, context):
         """
         Loads a language model.
@@ -103,6 +104,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             self.llm = AsyncLLMEngine.from_engine_args(engine_args)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
+        try:
+            engine_model_config = await self.llm.get_model_config()
+            self.tokenizer = get_tokenizer(
+                engine_model_config.tokenizer,
+                tokenizer_mode=engine_model_config.tokenizer_mode,
+                trust_remote_code=engine_model_config.trust_remote_code,
+                truncation_side="left",
+            )
+        except Exception as err:
+            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+
         return backend_pb2.Result(message="Model loaded successfully", success=True)

     async def Predict(self, request, context):
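
Note on the signature change: AsyncLLMEngine.get_model_config() is a coroutine, so LoadModel has to become async def to await it, and the servicer therefore needs to run under gRPC's asyncio API rather than the threaded one. The snippet below is a minimal sketch of such a server, not code from this commit; it assumes the proto service is named Backend (so the generated helper is add_BackendServicer_to_server) and uses a placeholder address.

# Illustrative sketch only: an asyncio gRPC server so that "async def"
# servicer methods like LoadModel can actually be awaited by the framework.
# backend_pb2_grpc and BackendServicer are the repository's generated stubs
# and the servicer class from the file above; the helper name and port are
# assumptions, not taken from this commit.
import asyncio
import grpc
import backend_pb2_grpc

async def serve(address: str = "localhost:50051") -> None:
    server = grpc.aio.server()
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    await server.start()
    await server.wait_for_termination()

if __name__ == "__main__":
    asyncio.run(serve())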
@@ -161,9 +174,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.Seed != 0:
             sampling_params.seed = request.Seed

+        prompt = request.Prompt
+
+        # If the tokenizer template is enabled and messages are provided instead of a prompt, apply the chat template
+        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
+            prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
+
         # Generate text
         request_id = random_uuid()
-        outputs = self.llm.generate(request.Prompt, sampling_params, request_id)
+        outputs = self.llm.generate(prompt, sampling_params, request_id)

         # Stream the results
         generated_text = ""
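
For reference, a standalone sketch of what apply_chat_template produces, outside the gRPC backend. The model name and messages are placeholders, not values from this commit; any tokenizer that ships a chat template behaves the same way.

# Illustrative only: render an OpenAI-style message list into a single prompt string.
from transformers import AutoTokenizer

# Placeholder model with a chat template; swap in whatever model the backend loads.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

# tokenize=False returns the rendered prompt string instead of token ids;
# add_generation_prompt=True appends the assistant-turn marker so the model
# starts generating a reply rather than continuing the user message.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)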