feat: Add UseFastTokenizer

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Ettore Di Giacinto 2023-08-08 01:10:05 +02:00
parent 39805b09e5
commit 3c8fc37c56
10 changed files with 198 additions and 169 deletions

@@ -25,7 +25,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.Device != "":
             device = request.Device
-        tokenizer = AutoTokenizer.from_pretrained(request.Model, use_fast=True)
+        tokenizer = AutoTokenizer.from_pretrained(request.Model, use_fast=request.UseFastTokenizer)
         model = AutoGPTQForCausalLM.from_quantized(request.Model,
                 model_basename=request.ModelBaseName,
@@ -42,14 +42,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         return backend_pb2.Result(message="Model loaded successfully", success=True)

     def Predict(self, request, context):
+        penalty = 1.0
+        if request.Penalty != 0.0:
+            penalty = request.Penalty
+        tokens = 512
+        if request.Tokens != 0:
+            tokens = request.Tokens
+        top_p = 0.95
+        if request.TopP != 0.0:
+            top_p = request.TopP
         # Implement Predict RPC
         pipeline = TextGenerationPipeline(
             model=self.model,
             tokenizer=self.tokenizer,
-            max_new_tokens=request.Tokens,
+            max_new_tokens=tokens,
             temperature=request.Temperature,
-            top_p=request.TopP,
-            repetition_penalty=request.Penalty,
+            top_p=top_p,
+            repetition_penalty=penalty,
         )
         return backend_pb2.Result(message=bytes(pipeline(request.Prompt)[0]["generated_text"]))
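The diff makes two behavioral changes: LoadModel now threads the new UseFastTokenizer field through to AutoTokenizer.from_pretrained(..., use_fast=...) instead of hardcoding use_fast=True, and Predict substitutes server-side defaults (repetition penalty 1.0, 512 tokens, top_p 0.95) when the corresponding request fields carry proto3's zero values, since a proto3 scalar that was never set is indistinguishable from one explicitly set to 0. Below is a minimal client-side sketch; the BackendStub name, the ModelOptions/PredictOptions message names, the model id, and the server address are assumptions not shown in this diff.

    # Hypothetical gRPC client exercising the new UseFastTokenizer flag.
    # Assumed (not confirmed by this diff): BackendStub, the
    # ModelOptions/PredictOptions message names, and the address.
    import grpc
    import backend_pb2
    import backend_pb2_grpc

    with grpc.insecure_channel("localhost:50051") as channel:
        stub = backend_pb2_grpc.BackendStub(channel)

        # Opt in to the fast (Rust-based) tokenizer. With the field unset,
        # proto3 defaults the bool to False.
        res = stub.LoadModel(backend_pb2.ModelOptions(
            Model="some/gptq-model",  # hypothetical model id
            UseFastTokenizer=True,
        ))
        assert res.success, res.message

        # Leaving Tokens/TopP/Penalty at their zero values triggers the new
        # server-side defaults: 512 tokens, top_p 0.95, penalty 1.0.
        reply = stub.Predict(backend_pb2.PredictOptions(
            Prompt="Hello",
            Temperature=0.7,
        ))
        print(reply.message)

One consequence worth noting: because proto3 booleans default to False, clients that previously relied on the hardcoded use_fast=True must now set UseFastTokenizer explicitly to keep the fast tokenizer.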