feat: use tokenizer.apply_chat_template() in vLLM (#1990)

Use tokenizer.apply_chat_template() in vLLM

Signed-off-by: Ludovic LEROUX <ludovic@inpher.io>
This commit is contained in:
Ludovic Leroux 2024-04-11 13:20:22 -04:00 committed by GitHub
parent cbda06fb96
commit 12c0d9443e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 3088 additions and 989 deletions

File diff suppressed because one or more lines are too long

View file

@ -64,6 +64,26 @@ class BackendStub(object):
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)
self.StoresSet = channel.unary_unary(
'/backend.Backend/StoresSet',
request_serializer=backend__pb2.StoresSetOptions.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.StoresDelete = channel.unary_unary(
'/backend.Backend/StoresDelete',
request_serializer=backend__pb2.StoresDeleteOptions.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.StoresGet = channel.unary_unary(
'/backend.Backend/StoresGet',
request_serializer=backend__pb2.StoresGetOptions.SerializeToString,
response_deserializer=backend__pb2.StoresGetResult.FromString,
)
self.StoresFind = channel.unary_unary(
'/backend.Backend/StoresFind',
request_serializer=backend__pb2.StoresFindOptions.SerializeToString,
response_deserializer=backend__pb2.StoresFindResult.FromString,
)
class BackendServicer(object):
@ -129,6 +149,30 @@ class BackendServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StoresSet(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StoresDelete(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StoresGet(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StoresFind(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = {
@ -182,6 +226,26 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
'StoresSet': grpc.unary_unary_rpc_method_handler(
servicer.StoresSet,
request_deserializer=backend__pb2.StoresSetOptions.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'StoresDelete': grpc.unary_unary_rpc_method_handler(
servicer.StoresDelete,
request_deserializer=backend__pb2.StoresDeleteOptions.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'StoresGet': grpc.unary_unary_rpc_method_handler(
servicer.StoresGet,
request_deserializer=backend__pb2.StoresGetOptions.FromString,
response_serializer=backend__pb2.StoresGetResult.SerializeToString,
),
'StoresFind': grpc.unary_unary_rpc_method_handler(
servicer.StoresFind,
request_deserializer=backend__pb2.StoresFindOptions.FromString,
response_serializer=backend__pb2.StoresFindResult.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers)
@ -361,3 +425,71 @@ class Backend(object):
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def StoresSet(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/StoresSet',
backend__pb2.StoresSetOptions.SerializeToString,
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def StoresDelete(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/StoresDelete',
backend__pb2.StoresDeleteOptions.SerializeToString,
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def StoresGet(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/StoresGet',
backend__pb2.StoresGetOptions.SerializeToString,
backend__pb2.StoresGetResult.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def StoresFind(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/StoresFind',
backend__pb2.StoresFindOptions.SerializeToString,
backend__pb2.StoresFindResult.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View file

@ -14,6 +14,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.transformers_utils.tokenizer import get_tokenizer
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -71,7 +72,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
async def LoadModel(self, request, context):
"""
Loads a language model.
@ -103,6 +104,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.llm = AsyncLLMEngine.from_engine_args(engine_args)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
try:
engine_model_config = await self.llm.get_model_config()
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left",
)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(message="Model loaded successfully", success=True)
async def Predict(self, request, context):
@ -161,9 +174,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.Seed != 0:
sampling_params.seed = request.Seed
prompt = request.Prompt
# If tokenizer template is enabled and messages are provided instead of prompt apply the tokenizer template
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
# Generate text
request_id = random_uuid()
outputs = self.llm.generate(request.Prompt, sampling_params, request_id)
outputs = self.llm.generate(prompt, sampling_params, request_id)
# Stream the results
generated_text = ""