feat(extra-backends): Improvements, adding mamba example (#1618)

* feat(extra-backends): Improvements

vllm: add max_tokens, wire up stream event
mamba: fixups, adding examples for mamba-chat

* examples(mamba-chat): add

* docs: update
Author: Ettore Di Giacinto
Date: 2024-01-20 17:56:08 +01:00 (committed by GitHub)
Commit: 06cd9ef98d (parent f3d71f8819)
4 changed files with 58 additions and 29 deletions
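
The commit message also mentions a new mamba-chat example and a docs update; those files are among the 4 changed files but their hunks are not shown below. As a purely illustrative sketch (endpoint, port, and model name are assumptions, not taken from this commit), a configured mamba-chat model would be queried through LocalAI's OpenAI-compatible API like any other model:

from openai import OpenAI

# Assumptions: LocalAI listening on localhost:8080 and a model configured
# under the name "mamba-chat"; neither is guaranteed by this diff.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")
resp = client.chat.completions.create(
    model="mamba-chat",
    messages=[{"role": "user", "content": "Hello, who are you?"}],
    max_tokens=200,
)
print(resp.choices[0].message.content)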

mamba backend:

@@ -117,9 +117,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.Tokens == 0:
             max_tokens = 2000
-        encoded_input = self.tokenizer(request.Prompt)
-        out = self.model.generate(input_ids=encoded_input["input_ids"], max_length=max_tokens, temperature=request.Temperratur,
+        # encoded_input = self.tokenizer(request.Prompt)
+        tokens = self.tokenizer(request.Prompt, return_tensors="pt")
+        input_ids = tokens.input_ids.to(device="cuda")
+        out = self.model.generate(input_ids=input_ids, max_length=max_tokens, temperature=request.Temperature,
                                   top_p=request.TopP, eos_token_id=self.tokenizer.eos_token_id)
         decoded = self.tokenizer.batch_decode(out)
@@ -130,7 +131,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.Prompt in generated_text:
             generated_text = generated_text.replace(request.Prompt, "")
-        return backend_pb2.Result(message=bytes(generated_text, encoding='utf-8'))
+        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
 
     def PredictStream(self, request, context):
         """
@@ -143,11 +144,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Result: The predict stream result.
         """
-        # Implement PredictStream RPC
-        #for reply in some_data_generator():
-        #    yield reply
-        # Not implemented yet
-        return self.Predict(request, context)
+        yield self.Predict(request, context)
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
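
For reference, the new mamba code path above is the standard tokenize, move tensors to CUDA, generate, batch-decode flow. A minimal standalone sketch of the same flow; the model-loading lines are an assumption borrowed from the mamba-chat README (mamba_ssm's MambaLMHeadModel), since the diff only shows the generation call:

import torch
from transformers import AutoTokenizer
# Assumed loader, not shown in the diff above.
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)

prompt = "What is a state-space model?"
tokens = tokenizer(prompt, return_tensors="pt")   # same call as the new code
input_ids = tokens.input_ids.to(device="cuda")    # tensors must live on the GPU

out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9,
                     top_p=0.9, eos_token_id=tokenizer.eos_token_id)
decoded = tokenizer.batch_decode(out)
print(decoded[0])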

vllm backend:

@@ -97,12 +97,16 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             context: The gRPC context.
         Returns:
-            backend_pb2.Result: The predict result.
+            backend_pb2.Reply: The predict result.
         """
         if request.TopP == 0:
             request.TopP = 0.9
-        sampling_params = SamplingParams(temperature=request.Temperature, top_p=request.TopP)
+        max_tokens = 200
+        if request.Tokens > 0:
+            max_tokens = request.Tokens
+        sampling_params = SamplingParams(max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
         outputs = self.llm.generate([request.Prompt], sampling_params)
         generated_text = outputs[0].outputs[0].text
@@ -110,7 +114,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.Prompt in generated_text:
             generated_text = generated_text.replace(request.Prompt, "")
-        return backend_pb2.Result(message=bytes(generated_text, encoding='utf-8'))
+        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
 
     def PredictStream(self, request, context):
         """
@@ -123,11 +127,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Result: The predict stream result.
         """
-        # Implement PredictStream RPC
-        #for reply in some_data_generator():
-        #    yield reply
-        # Not implemented yet
-        return self.Predict(request, context)
+        yield self.Predict(request, context)
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
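
The vllm change above caps generation length: default to 200 tokens unless the request sets Tokens, and fall back to top_p=0.9 when TopP is unset. A sketch of the same logic against vLLM's offline API (the model name is only an example, not part of this commit):

from vllm import LLM, SamplingParams

def build_sampling_params(tokens: int, temperature: float, top_p: float) -> SamplingParams:
    # Mirrors the diff: default to 200 tokens and top_p=0.9 when the request leaves them unset.
    max_tokens = tokens if tokens > 0 else 200
    top_p = top_p if top_p > 0 else 0.9
    return SamplingParams(max_tokens=max_tokens, temperature=temperature, top_p=top_p)

llm = LLM(model="facebook/opt-125m")  # example model
params = build_sampling_params(tokens=0, temperature=0.8, top_p=0.0)
outputs = llm.generate(["What is LocalAI?"], params)
print(outputs[0].outputs[0].text)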