mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-20 18:45:00 +00:00)
fix: vllm - use AsyncLLMEngine to allow true streaming mode (#1749)
* fix: use vllm AsyncLLMEngine to bring true streaming

  The current vLLM implementation uses LLMEngine, which was designed for offline batch inference; as a result, streaming mode outputs all blobs at once at the end of the inference. This PR reworks the gRPC server to use asyncio and grpc.aio, in combination with vLLM's AsyncLLMEngine, to bring true streaming mode. It also passes more parameters to vLLM during inference (presence_penalty, frequency_penalty, stop, ignore_eos, seed, ...).

* Remove unused import
parent ff88c390bb
commit 0135e1e3b9
1 changed file with 93 additions and 45 deletions
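For context, here is a minimal standalone sketch of the streaming pattern the reworked backend relies on (async engine setup, cumulative outputs, delta extraction), distilled from the diff below. The model name is only a placeholder assumption; the vLLM calls follow the API as used in this PR.

#!/usr/bin/env python3
# Minimal sketch of the AsyncLLMEngine streaming pattern used in the diff below.
# Assumption: "facebook/opt-125m" is just a placeholder model name.
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


async def main():
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    sampling_params = SamplingParams(top_p=0.9, max_tokens=200)

    # generate() returns an async generator; each RequestOutput carries the
    # cumulative text so far, so the newly generated delta is the suffix.
    generated_text = ""
    async for request_output in engine.generate("Hello, my name is", sampling_params, random_uuid()):
        iteration_text = request_output.outputs[0].text
        print(iteration_text.removeprefix(generated_text), end="", flush=True)
        generated_text = iteration_text

if __name__ == "__main__":
    asyncio.run(main())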
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
+import asyncio
 from concurrent import futures
-import time
 import argparse
 import signal
 import sys
@@ -10,7 +10,10 @@ import backend_pb2
 import backend_pb2_grpc
 
 import grpc
-from vllm import LLM, SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
@@ -79,16 +82,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Result: The load model result.
         """
-        try:
-            if request.Quantization != "":
-                self.llm = LLM(model=request.Model, quantization=request.Quantization)
-            else:
-                self.llm = LLM(model=request.Model)
+        engine_args = AsyncEngineArgs(
+            model=request.Model,
+        )
+
+        if request.Quantization != "":
+            engine_args.quantization = request.Quantization
+
+        try:
+            self.llm = AsyncLLMEngine.from_engine_args(engine_args)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         return backend_pb2.Result(message="Model loaded successfully", success=True)
 
-    def Predict(self, request, context):
+    async def Predict(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters.
 
@@ -99,24 +106,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Reply: The predict result.
         """
-        if request.TopP == 0:
-            request.TopP = 0.9
-
-        max_tokens = 200
-        if request.Tokens > 0:
-            max_tokens = request.Tokens
-
-        sampling_params = SamplingParams(max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
-        outputs = self.llm.generate([request.Prompt], sampling_params)
-
-        generated_text = outputs[0].outputs[0].text
-        # Remove prompt from response if present
-        if request.Prompt in generated_text:
-            generated_text = generated_text.replace(request.Prompt, "")
-
-        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
-
-    def PredictStream(self, request, context):
+        gen = self._predict(request, context, streaming=False)
+        res = await gen.__anext__()
+        return res
+
+    async def PredictStream(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters, and streams the results.
 
@@ -127,30 +121,84 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Result: The predict stream result.
         """
-        yield self.Predict(request, context)
-
-def serve(address):
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
-    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
-    server.add_insecure_port(address)
-    server.start()
-    print("Server started. Listening on: " + address, file=sys.stderr)
-
-    # Define the signal handler function
-    def signal_handler(sig, frame):
-        print("Received termination signal. Shutting down...")
-        server.stop(0)
-        sys.exit(0)
-
-    # Set the signal handlers for SIGINT and SIGTERM
-    signal.signal(signal.SIGINT, signal_handler)
-    signal.signal(signal.SIGTERM, signal_handler)
-
-    try:
-        while True:
-            time.sleep(_ONE_DAY_IN_SECONDS)
-    except KeyboardInterrupt:
-        server.stop(0)
+        iterations = self._predict(request, context, streaming=True)
+        try:
+            async for iteration in iterations:
+                yield iteration
+        finally:
+            await iterations.aclose()
+
+    async def _predict(self, request, context, streaming=False):
+
+        # Build sampling parameters
+        sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
+        if request.TopP != 0:
+            sampling_params.top_p = request.TopP
+        if request.Tokens > 0:
+            sampling_params.max_tokens = request.Tokens
+        if request.Temperature != 0:
+            sampling_params.temperature = request.Temperature
+        if request.TopK != 0:
+            sampling_params.top_k = request.TopK
+        if request.PresencePenalty != 0:
+            sampling_params.presence_penalty = request.PresencePenalty
+        if request.FrequencyPenalty != 0:
+            sampling_params.frequency_penalty = request.FrequencyPenalty
+        if request.StopPrompts:
+            sampling_params.stop = request.StopPrompts
+        if request.IgnoreEOS:
+            sampling_params.ignore_eos = request.IgnoreEOS
+        if request.Seed != 0:
+            sampling_params.seed = request.Seed
+
+        # Generate text
+        request_id = random_uuid()
+        outputs = self.llm.generate(request.Prompt, sampling_params, request_id)
+
+        # Stream the results
+        generated_text = ""
+        try:
+            async for request_output in outputs:
+                iteration_text = request_output.outputs[0].text
+
+                if streaming:
+                    # Remove text already sent as vllm concatenates the text from previous yields
+                    delta_iteration_text = iteration_text.removeprefix(generated_text)
+                    # Send the partial result
+                    yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))
+
+                # Keep track of text generated
+                generated_text = iteration_text
+        finally:
+            await outputs.aclose()
+
+        # If streaming, we already sent everything
+        if streaming:
+            return
+
+        # Sending the final generated text
+        yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+
+async def serve(address):
+    # Start asyncio gRPC server
+    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
+    # Add the servicer to the server
+    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+    # Bind the server to the address
+    server.add_insecure_port(address)
+
+    # Gracefully shutdown the server on SIGTERM or SIGINT
+    loop = asyncio.get_event_loop()
+    for sig in (signal.SIGINT, signal.SIGTERM):
+        loop.add_signal_handler(
+            sig, lambda: asyncio.ensure_future(server.stop(5))
+        )
+
+    # Start the server
+    await server.start()
+    print("Server started. Listening on: " + address, file=sys.stderr)
+    # Wait for the server to be terminated
+    await server.wait_for_termination()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the gRPC server.")
@@ -159,4 +207,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    serve(args.addr)
+    asyncio.run(serve(args.addr))
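Below is a hedged sketch of how a client might consume the new streaming endpoint over grpc.aio. The stub name (BackendStub), the request message type (PredictOptions), and the address are assumptions about the generated backend_pb2/backend_pb2_grpc modules; only the PredictStream method and the Reply.message bytes field are taken from the diff above.

#!/usr/bin/env python3
# Usage sketch: consuming the streaming RPC with grpc.aio.
# Assumptions: backend_pb2_grpc.BackendStub and backend_pb2.PredictOptions are
# the generated stub and request message; field names (Prompt, Tokens, TopP)
# mirror the servicer code above. "localhost:50051" is a placeholder address.
import asyncio

import grpc
import backend_pb2
import backend_pb2_grpc


async def main():
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = backend_pb2_grpc.BackendStub(channel)
        request = backend_pb2.PredictOptions(Prompt="Hello, my name is", Tokens=64, TopP=0.9)

        # Each Reply now carries only the newly generated delta, so chunks can
        # be printed (or forwarded) as they arrive instead of all at the end.
        async for reply in stub.PredictStream(request):
            print(reply.message.decode("utf-8"), end="", flush=True)

if __name__ == "__main__":
    asyncio.run(main())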