From 0135e1e3b90d361b337f84eb76ee9d784ea20c40 Mon Sep 17 00:00:00 2001 From: Ludovic Leroux Date: Sat, 24 Feb 2024 05:48:45 -0500 Subject: [PATCH] fix: vllm - use AsyncLLMEngine to allow true streaming mode (#1749) * fix: use vllm AsyncLLMEngine to bring true stream Current vLLM implementation uses the LLMEngine, which was designed for offline batch inference, which results in the streaming mode outputing all blobs at once at the end of the inference. This PR reworks the gRPC server to use asyncio and gRPC.aio, in combination with vLLM's AsyncLLMEngine to bring true stream mode. This PR also passes more parameters to vLLM during inference (presence_penalty, frequency_penalty, stop, ignore_eos, seed, ...). * Remove unused import --- backend/python/vllm/backend_vllm.py | 138 +++++++++++++++++++--------- 1 file changed, 93 insertions(+), 45 deletions(-) diff --git a/backend/python/vllm/backend_vllm.py b/backend/python/vllm/backend_vllm.py index d5b8b51f..8f8c4ee0 100644 --- a/backend/python/vllm/backend_vllm.py +++ b/backend/python/vllm/backend_vllm.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +import asyncio from concurrent import futures -import time import argparse import signal import sys @@ -10,7 +10,10 @@ import backend_pb2 import backend_pb2_grpc import grpc -from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -79,16 +82,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): Returns: backend_pb2.Result: The load model result. """ + engine_args = AsyncEngineArgs( + model=request.Model, + ) + + if request.Quantization != "": + engine_args.quantization = request.Quantization + try: - if request.Quantization != "": - self.llm = LLM(model=request.Model, quantization=request.Quantization) - else: - self.llm = LLM(model=request.Model) + self.llm = AsyncLLMEngine.from_engine_args(engine_args) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(message="Model loaded successfully", success=True) - def Predict(self, request, context): + async def Predict(self, request, context): """ Generates text based on the given prompt and sampling parameters. @@ -99,24 +106,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): Returns: backend_pb2.Reply: The predict result. """ - if request.TopP == 0: - request.TopP = 0.9 + gen = self._predict(request, context, streaming=False) + res = await gen.__anext__() + return res - max_tokens = 200 - if request.Tokens > 0: - max_tokens = request.Tokens - - sampling_params = SamplingParams(max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP) - outputs = self.llm.generate([request.Prompt], sampling_params) - - generated_text = outputs[0].outputs[0].text - # Remove prompt from response if present - if request.Prompt in generated_text: - generated_text = generated_text.replace(request.Prompt, "") - - return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8')) - - def PredictStream(self, request, context): + async def PredictStream(self, request, context): """ Generates text based on the given prompt and sampling parameters, and streams the results. @@ -127,30 +121,84 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): Returns: backend_pb2.Result: The predict stream result. """ - yield self.Predict(request, context) + iterations = self._predict(request, context, streaming=True) + try: + async for iteration in iterations: + yield iteration + finally: + await iterations.aclose() -def serve(address): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) + async def _predict(self, request, context, streaming=False): + + # Build sampling parameters + sampling_params = SamplingParams(top_p=0.9, max_tokens=200) + if request.TopP != 0: + sampling_params.top_p = request.TopP + if request.Tokens > 0: + sampling_params.max_tokens = request.Tokens + if request.Temperature != 0: + sampling_params.temperature = request.Temperature + if request.TopK != 0: + sampling_params.top_k = request.TopK + if request.PresencePenalty != 0: + sampling_params.presence_penalty = request.PresencePenalty + if request.FrequencyPenalty != 0: + sampling_params.frequency_penalty = request.FrequencyPenalty + if request.StopPrompts: + sampling_params.stop = request.StopPrompts + if request.IgnoreEOS: + sampling_params.ignore_eos = request.IgnoreEOS + if request.Seed != 0: + sampling_params.seed = request.Seed + + # Generate text + request_id = random_uuid() + outputs = self.llm.generate(request.Prompt, sampling_params, request_id) + + # Stream the results + generated_text = "" + try: + async for request_output in outputs: + iteration_text = request_output.outputs[0].text + + if streaming: + # Remove text already sent as vllm concatenates the text from previous yields + delta_iteration_text = iteration_text.removeprefix(generated_text) + # Send the partial result + yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8')) + + # Keep track of text generated + generated_text = iteration_text + finally: + await outputs.aclose() + + # If streaming, we already sent everything + if streaming: + return + + # Sending the final generated text + yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8')) + +async def serve(address): + # Start asyncio gRPC server + server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) + # Add the servicer to the server backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + # Bind the server to the address server.add_insecure_port(address) - server.start() + + # Gracefully shutdown the server on SIGTERM or SIGINT + loop = asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler( + sig, lambda: asyncio.ensure_future(server.stop(5)) + ) + + # Start the server + await server.start() print("Server started. Listening on: " + address, file=sys.stderr) - - # Define the signal handler function - def signal_handler(sig, frame): - print("Received termination signal. Shutting down...") - server.stop(0) - sys.exit(0) - - # Set the signal handlers for SIGINT and SIGTERM - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - try: - while True: - time.sleep(_ONE_DAY_IN_SECONDS) - except KeyboardInterrupt: - server.stop(0) + # Wait for the server to be terminated + await server.wait_for_termination() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") @@ -159,4 +207,4 @@ if __name__ == "__main__": ) args = parser.parse_args() - serve(args.addr) + asyncio.run(serve(args.addr)) \ No newline at end of file