diff --git a/backend/python/common-env/transformers/install.sh b/backend/python/common-env/transformers/install.sh
index e268fcc8..8502adde 100644
--- a/backend/python/common-env/transformers/install.sh
+++ b/backend/python/common-env/transformers/install.sh
@@ -25,7 +25,7 @@ if [ -d "/opt/intel" ]; then
     # Intel GPU: If the directory exists, we assume we are using the intel image
     # (no conda env)
     # https://github.com/intel/intel-extension-for-pytorch/issues/538
-    pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed
+    pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed optimum[openvino]
 fi
 
 if [ "$PIP_CACHE_PURGE" = true ] ; then
diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index a8702021..04324d9b 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -9,6 +9,7 @@ import signal
 import sys
 import os
 from threading import Thread
+import asyncio
 import time
 
 import backend_pb2
@@ -205,17 +206,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
 
-    def Predict(self, request, context, streaming=False):
-        """
-        Generates text based on the given prompt and sampling parameters.
-
-        Args:
-            request: The predict request.
-            context: The gRPC context.
-
-        Returns:
-            backend_pb2.Reply: The predict result.
-        """
+    async def _predict(self, request, context, streaming=False):
         set_seed(request.Seed)
         if request.TopP == 0:
             request.TopP = 0.9
@@ -248,21 +239,54 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             thread=Thread(target=self.model.generate, kwargs=config)
             thread.start()
             generated_text = ""
-            for new_text in streamer:
-                generated_text += new_text
-                yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
+            try:
+                for new_text in streamer:
+                    generated_text += new_text
+                    yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
+            finally:
+                thread.join()
         else:
-            outputs = self.model.generate(inputs["input_ids"],
-                                      max_new_tokens=max_tokens,
-                                      temperature=request.Temperature,
-                                      top_p=request.TopP,
-                                      top_k=request.TopK,
-                                      do_sample=True,
-                                      pad_token=self.tokenizer.eos_token_id)
+            if XPU and self.OV == False:
+                outputs = self.model.generate(inputs["input_ids"],
+                                        max_new_tokens=max_tokens,
+                                        temperature=request.Temperature,
+                                        top_p=request.TopP,
+                                        top_k=request.TopK,
+                                        do_sample=True,
+                                        pad_token=self.tokenizer.eos_token_id)
+            else:
+                outputs = self.model.generate(inputs["input_ids"],
+                                        max_new_tokens=max_tokens,
+                                        temperature=request.Temperature,
+                                        top_p=request.TopP,
+                                        top_k=request.TopK,
+                                        do_sample=True,
+                                        attention_mask=inputs["attention_mask"],
+                                        eos_token_id=self.tokenizer.eos_token_id,
+                                        pad_token_id=self.tokenizer.eos_token_id)
             generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
 
-        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
 
-    def PredictStream(self, request, context):
+        if streaming:
+            return
+
+        yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+
+    async def Predict(self, request, context):
+        """
+        Generates text based on the given prompt and sampling parameters.
+
+        Args:
+            request: The predict request.
+            context: The gRPC context.
+
+        Returns:
+            backend_pb2.Reply: The predict result.
+        """
+        gen = self._predict(request, context, streaming=False)
+        res = await gen.__anext__()
+        return res
+
+    async def PredictStream(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters, and streams the results.
 
@@ -273,33 +297,33 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Result: The predict stream result.
         """
-        iterations = self.Predict(request, context, streaming=True)
-        for iteration in iterations:
-            yield iteration
+        iterations = self._predict(request, context, streaming=True)
+        try:
+            async for iteration in iterations:
+                yield iteration
+        finally:
+            await iterations.aclose()
 
-
-def serve(address):
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
+async def serve(address):
+    # Start asyncio gRPC server
+    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
+    # Add the servicer to the server
     backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+    # Bind the server to the address
     server.add_insecure_port(address)
-    server.start()
+
+    # Gracefully shutdown the server on SIGTERM or SIGINT
+    loop = asyncio.get_event_loop()
+    for sig in (signal.SIGINT, signal.SIGTERM):
+        loop.add_signal_handler(
+            sig, lambda: asyncio.ensure_future(server.stop(5))
+        )
+
+    # Start the server
+    await server.start()
     print("Server started. Listening on: " + address, file=sys.stderr)
-
-    # Define the signal handler function
-    def signal_handler(sig, frame):
-        print("Received termination signal. Shutting down...")
-        server.stop(0)
-        sys.exit(0)
-
-    # Set the signal handlers for SIGINT and SIGTERM
-    signal.signal(signal.SIGINT, signal_handler)
-    signal.signal(signal.SIGTERM, signal_handler)
-
-    try:
-        while True:
-            time.sleep(_ONE_DAY_IN_SECONDS)
-    except KeyboardInterrupt:
-        server.stop(0)
+    # Wait for the server to be terminated
+    await server.wait_for_termination()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the gRPC server.")
@@ -308,4 +332,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    serve(args.addr)
+    asyncio.run(serve(args.addr))
\ No newline at end of file