fix: exllama2 backend (#1484)

Signed-off-by: Sertac Ozercan <sozercan@gmail.com>
This commit is contained in:
Sertaç Özercan 2023-12-24 00:32:12 -08:00 committed by GitHub
parent eaa899df63
commit 6597881854
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -7,7 +7,8 @@ import backend_pb2_grpc
import argparse import argparse
import signal import signal
import sys import sys
import os, glob import os
import glob
from pathlib import Path from pathlib import Path
import torch import torch
@ -40,6 +41,7 @@ MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
class BackendServicer(backend_pb2_grpc.BackendServicer): class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context): def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8')) return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context): def LoadModel(self, request, context):
try: try:
model_directory = request.ModelFile model_directory = request.ModelFile
@ -85,13 +87,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.Tokens != 0: if request.Tokens != 0:
tokens = request.Tokens tokens = request.Tokens
output = self.generator.generate_simple(request.Prompt, settings, tokens, seed = self.seed) output = self.generator.generate_simple(
request.Prompt, settings, tokens)
# Remove prompt from response if present # Remove prompt from response if present
if request.Prompt in output: if request.Prompt in output:
output = output.replace(request.Prompt, "") output = output.replace(request.Prompt, "")
return backend_pb2.Result(message=bytes(t, encoding='utf-8')) return backend_pb2.Result(message=bytes(output, encoding='utf-8'))
def PredictStream(self, request, context): def PredictStream(self, request, context):
# Implement PredictStream RPC # Implement PredictStream RPC
@ -124,6 +127,7 @@ def serve(address):
except KeyboardInterrupt: except KeyboardInterrupt:
server.stop(0) server.stop(0)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.") parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument( parser.add_argument(