fix: exllama2 backend (#1484)

Signed-off-by: Sertac Ozercan <sozercan@gmail.com>
Sertaç Özercan 2023-12-24 00:32:12 -08:00 committed by GitHub
parent eaa899df63
commit 6597881854


@@ -7,7 +7,8 @@ import backend_pb2_grpc
 import argparse
 import signal
 import sys
-import os, glob
+import os
+import glob
 from pathlib import Path
 import torch
@@ -21,7 +22,7 @@ from exllamav2.generator import (
 )
-from exllamav2 import(
+from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
     ExLlamaV2Cache,
@@ -40,6 +41,7 @@ MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
 class BackendServicer(backend_pb2_grpc.BackendServicer):
     def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
     def LoadModel(self, request, context):
         try:
             model_directory = request.ModelFile
@@ -50,7 +52,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             model = ExLlamaV2(config)
-            cache = ExLlamaV2Cache(model, lazy = True)
+            cache = ExLlamaV2Cache(model, lazy=True)
             model.load_autosplit(cache)
             tokenizer = ExLlamaV2Tokenizer(config)
@@ -59,7 +61,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
-            self.generator= generator
+            self.generator = generator
             generator.warmup()
             self.model = model
@@ -85,17 +87,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.Tokens != 0:
             tokens = request.Tokens
-        output = self.generator.generate_simple(request.Prompt, settings, tokens, seed = self.seed)
+        output = self.generator.generate_simple(
+            request.Prompt, settings, tokens)
         # Remove prompt from response if present
         if request.Prompt in output:
             output = output.replace(request.Prompt, "")
-        return backend_pb2.Result(message=bytes(t, encoding='utf-8'))
+        return backend_pb2.Result(message=bytes(output, encoding='utf-8'))

     def PredictStream(self, request, context):
         # Implement PredictStream RPC
-        #for reply in some_data_generator():
-        #    yield reply
+        # for reply in some_data_generator():
+        #     yield reply
         # Not implemented yet
         return self.Predict(request, context)
@@ -124,6 +127,7 @@ def serve(address):
     except KeyboardInterrupt:
         server.stop(0)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the gRPC server.")
     parser.add_argument(
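
For context, a minimal client sketch of calling the fixed Predict RPC through the generated gRPC stub. The Prompt and Tokens request fields and the bytes "message" reply field are taken from the handler above; the stub class name (BackendStub), request message name (PredictOptions), and the server address are assumptions, not confirmed by this diff.

# Minimal sketch, not a definitive client for this backend.
# Assumed names: BackendStub, PredictOptions, and the address below.
import grpc
import backend_pb2
import backend_pb2_grpc

channel = grpc.insecure_channel("localhost:50051")  # hypothetical address
stub = backend_pb2_grpc.BackendStub(channel)

# Prompt and Tokens are the request fields read by Predict() in the diff above.
reply = stub.Predict(backend_pb2.PredictOptions(Prompt="Hello", Tokens=64))

# The reply carries UTF-8 bytes; after this fix it holds the generated output
# (previously the handler referenced an undefined name `t`).
print(reply.message.decode("utf-8"))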