feat(transformers): various enhancements to the transformers backend (#2468)

update transformers

* Handle Temperature = 0 as greedy search
* Handle custom words as stop words
* Implement KV cache
* Phi 3 no longer requires trust_remote_code: true
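For context, the sketch below shows, outside of LocalAI, the transformers generation behaviour these changes rely on: greedy search when the temperature is 0, custom stop words via StopStringCriteria, and KV caching via use_cache=True. It is a minimal illustration, not the backend code; the model id, prompt, and stop strings are placeholder assumptions.

# Standalone sketch; model id, prompt and stop strings are illustrative only.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteriaList,
    StopStringCriteria,
)

model_id = "microsoft/Phi-3-mini-4k-instruct"  # example; recent transformers loads Phi-3 without trust_remote_code
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

temperature = 0.0               # 0 means "be deterministic"
do_sample = temperature > 0     # greedy search when temperature == 0

stop_words = ["</s>", "\nUser:"]  # example stop strings
criteria = StoppingCriteriaList(
    [StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_words)]
)

inputs = tokenizer("Explain KV caching in one sentence.", return_tensors="pt")
gen_kwargs = dict(
    max_new_tokens=64,
    do_sample=do_sample,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    stopping_criteria=criteria,
    use_cache=True,             # reuse the KV cache across decoding steps
)
if do_sample:
    gen_kwargs["temperature"] = temperature  # only pass temperature when sampling
outputs = model.generate(**inputs, **gen_kwargs)
print(tokenizer.decode(outputs[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))

Stopping on strings rather than on token ids (the previous convert_tokens_to_ids approach) also matches stop words that do not tokenize to a single token.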
commit 4a239a4bff
parent 5ddaa19914
Author: fakezeta
Date:   2024-06-03 08:52:55 +02:00 (committed by GitHub)

backend/python/transformers/backend.py | 59 changes (file mode changed from executable to normal)

@@ -22,9 +22,9 @@ import torch.cuda
 XPU=os.environ.get("XPU", "0") == "1"
 if XPU:
-    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 else:
-    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -246,28 +246,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
         sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
-        # print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
-        # print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
 
     async def _predict(self, request, context, streaming=False):
         set_seed(request.Seed)
-        if request.TopP == 0:
-            request.TopP = 0.9
+        if request.TopP < 0 or request.TopP > 1:
+            request.TopP = 1
 
-        if request.TopK == 0:
-            request.TopK = 40
+        if request.TopK <= 0:
+            request.TopK = 50
+
+        if request.Temperature > 0 :
+            sample=True
+        else:
+            sample=False
+            request.TopP == None
+            request.TopK == None
+            request.Temperature == None
 
         prompt = request.Prompt
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
             prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
 
-        eos_token_id = self.tokenizer.eos_token_id
-        if request.StopPrompts:
-            eos_token_id = []
-            for word in request.StopPrompts:
-                eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))
-
         inputs = self.tokenizer(prompt, return_tensors="pt")
 
         if request.Tokens > 0:
@@ -281,6 +281,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             inputs = inputs.to("xpu")
             streaming = False
 
+        criteria=[]
+        if request.StopPrompts:
+            criteria = StoppingCriteriaList(
+                [
+                    StopStringCriteria(tokenizer=self.tokenizer, stop_strings=request.StopPrompts),
+                ]
+            )
+
         if streaming:
             streamer=TextIteratorStreamer(self.tokenizer,
                                           skip_prompt=True,
@@ -290,11 +298,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         temperature=request.Temperature,
                         top_p=request.TopP,
                         top_k=request.TopK,
-                        do_sample=True,
+                        do_sample=sample,
                         attention_mask=inputs["attention_mask"],
-                        eos_token_id=eos_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id,
                         pad_token_id=self.tokenizer.eos_token_id,
-                        streamer=streamer)
+                        streamer=streamer,
+                        stopping_criteria=criteria,
+                        use_cache=True,
+                        )
             thread=Thread(target=self.model.generate, kwargs=config)
             thread.start()
             generated_text = ""
@@ -311,18 +322,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         temperature=request.Temperature,
                         top_p=request.TopP,
                         top_k=request.TopK,
-                        do_sample=True,
+                        do_sample=sample,
                         pad_token=self.tokenizer.eos_token_id)
             else:
-                outputs = self.model.generate(inputs["input_ids"],
+                outputs = self.model.generate(**inputs,
                         max_new_tokens=max_tokens,
                         temperature=request.Temperature,
                         top_p=request.TopP,
                         top_k=request.TopK,
-                        do_sample=True,
-                        attention_mask=inputs["attention_mask"],
-                        eos_token_id=eos_token_id,
-                        pad_token_id=self.tokenizer.eos_token_id)
+                        do_sample=sample,
+                        eos_token_id=self.tokenizer.eos_token_id,
+                        pad_token_id=self.tokenizer.eos_token_id,
+                        stopping_criteria=criteria,
+                        use_cache=True,
+                        )
             generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
 
         if streaming: