diff --git a/backend/python/vllm/backend_vllm.py b/backend/python/vllm/backend_vllm.py
index d5b8b51f..8f8c4ee0 100644
--- a/backend/python/vllm/backend_vllm.py
+++ b/backend/python/vllm/backend_vllm.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
+import asyncio
 from concurrent import futures
-import time
 import argparse
 import signal
 import sys
@@ -10,7 +10,10 @@
 import backend_pb2
 import backend_pb2_grpc
 import grpc
-from vllm import LLM, SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
@@ -79,16 +82,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Result: The load model result.
         """
+        engine_args = AsyncEngineArgs(
+            model=request.Model,
+        )
+
+        if request.Quantization != "":
+            engine_args.quantization = request.Quantization
+
         try:
-            if request.Quantization != "":
-                self.llm = LLM(model=request.Model, quantization=request.Quantization)
-            else:
-                self.llm = LLM(model=request.Model)
+            self.llm = AsyncLLMEngine.from_engine_args(engine_args)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         return backend_pb2.Result(message="Model loaded successfully", success=True)
 
-    def Predict(self, request, context):
+    async def Predict(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters.
 
@@ -99,24 +106,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Reply: The predict result.
         """
-        if request.TopP == 0:
-            request.TopP = 0.9
+        gen = self._predict(request, context, streaming=False)
+        res = await gen.__anext__()
+        return res
 
-        max_tokens = 200
-        if request.Tokens > 0:
-            max_tokens = request.Tokens
-
-        sampling_params = SamplingParams(max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
-        outputs = self.llm.generate([request.Prompt], sampling_params)
-
-        generated_text = outputs[0].outputs[0].text
-        # Remove prompt from response if present
-        if request.Prompt in generated_text:
-            generated_text = generated_text.replace(request.Prompt, "")
-
-        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
-
-    def PredictStream(self, request, context):
+    async def PredictStream(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters, and streams the results.
 
@@ -127,30 +121,84 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Result: The predict stream result.
         """
-        yield self.Predict(request, context)
+        iterations = self._predict(request, context, streaming=True)
+        try:
+            async for iteration in iterations:
+                yield iteration
+        finally:
+            await iterations.aclose()
 
-def serve(address):
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
+    async def _predict(self, request, context, streaming=False):
+
+        # Build sampling parameters
+        sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
+        if request.TopP != 0:
+            sampling_params.top_p = request.TopP
+        if request.Tokens > 0:
+            sampling_params.max_tokens = request.Tokens
+        if request.Temperature != 0:
+            sampling_params.temperature = request.Temperature
+        if request.TopK != 0:
+            sampling_params.top_k = request.TopK
+        if request.PresencePenalty != 0:
+            sampling_params.presence_penalty = request.PresencePenalty
+        if request.FrequencyPenalty != 0:
+            sampling_params.frequency_penalty = request.FrequencyPenalty
+        if request.StopPrompts:
+            sampling_params.stop = request.StopPrompts
+        if request.IgnoreEOS:
+            sampling_params.ignore_eos = request.IgnoreEOS
+        if request.Seed != 0:
+            sampling_params.seed = request.Seed
+
+        # Generate text
+        request_id = random_uuid()
+        outputs = self.llm.generate(request.Prompt, sampling_params, request_id)
+
+        # Stream the results
+        generated_text = ""
+        try:
+            async for request_output in outputs:
+                iteration_text = request_output.outputs[0].text
+
+                if streaming:
+                    # Remove text already sent as vllm concatenates the text from previous yields
+                    delta_iteration_text = iteration_text.removeprefix(generated_text)
+                    # Send the partial result
+                    yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))
+
+                # Keep track of text generated
+                generated_text = iteration_text
+        finally:
+            await outputs.aclose()
+
+        # If streaming, we already sent everything
+        if streaming:
+            return
+
+        # Sending the final generated text
+        yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+
+async def serve(address):
+    # Start asyncio gRPC server
+    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
+    # Add the servicer to the server
     backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+    # Bind the server to the address
     server.add_insecure_port(address)
-    server.start()
+
+    # Gracefully shutdown the server on SIGTERM or SIGINT
+    loop = asyncio.get_event_loop()
+    for sig in (signal.SIGINT, signal.SIGTERM):
+        loop.add_signal_handler(
+            sig, lambda: asyncio.ensure_future(server.stop(5))
+        )
+
+    # Start the server
+    await server.start()
     print("Server started. Listening on: " + address, file=sys.stderr)
-
-    # Define the signal handler function
-    def signal_handler(sig, frame):
-        print("Received termination signal. Shutting down...")
-        server.stop(0)
-        sys.exit(0)
-
-    # Set the signal handlers for SIGINT and SIGTERM
-    signal.signal(signal.SIGINT, signal_handler)
-    signal.signal(signal.SIGTERM, signal_handler)
-
-    try:
-        while True:
-            time.sleep(_ONE_DAY_IN_SECONDS)
-    except KeyboardInterrupt:
-        server.stop(0)
+    # Wait for the server to be terminated
+    await server.wait_for_termination()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the gRPC server.")
@@ -159,4 +207,4 @@
     )
     args = parser.parse_args()
 
-    serve(args.addr)
+    asyncio.run(serve(args.addr))
\ No newline at end of file