mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat(extra-backends): Improvements, adding mamba example (#1618)
* feat(extra-backends): Improvements vllm: add max_tokens, wire up stream event mamba: fixups, adding examples for mamba-chat * examples(mamba-chat): add * docs: update
This commit is contained in:
parent
f3d71f8819
commit
06cd9ef98d
4 changed files with 58 additions and 29 deletions
|
@ -97,12 +97,16 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Result: The predict result.
|
||||
backend_pb2.Reply: The predict result.
|
||||
"""
|
||||
if request.TopP == 0:
|
||||
request.TopP = 0.9
|
||||
|
||||
sampling_params = SamplingParams(temperature=request.Temperature, top_p=request.TopP)
|
||||
max_tokens = 200
|
||||
if request.Tokens > 0:
|
||||
max_tokens = request.Tokens
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
|
||||
outputs = self.llm.generate([request.Prompt], sampling_params)
|
||||
|
||||
generated_text = outputs[0].outputs[0].text
|
||||
|
@ -110,7 +114,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
if request.Prompt in generated_text:
|
||||
generated_text = generated_text.replace(request.Prompt, "")
|
||||
|
||||
return backend_pb2.Result(message=bytes(generated_text, encoding='utf-8'))
|
||||
return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
"""
|
||||
|
@ -123,11 +127,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
Returns:
|
||||
backend_pb2.Result: The predict stream result.
|
||||
"""
|
||||
# Implement PredictStream RPC
|
||||
#for reply in some_data_generator():
|
||||
# yield reply
|
||||
# Not implemented yet
|
||||
return self.Predict(request, context)
|
||||
yield self.Predict(request, context)
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue