Mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-20 10:35:01 +00:00)
feat(transformers): various enhancements to the transformers backend (#2468)
Update the transformers backend:

* Handle Temperature = 0 as greedy search
* Handle custom words as stop words
* Implement KV cache
* Phi 3 no longer requires trust_remote_code: true

Short illustrative sketches of these behaviors follow below.
This commit is contained in:
parent 5ddaa19914
commit 4a239a4bff
1 changed file with 36 additions and 23 deletions
backend/python/transformers/backend.py (59 lines changed; Executable file → Normal file)
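For context on the first bullet, here is a minimal sketch (not taken from the backend) of what treating Temperature = 0 as greedy search means in terms of Hugging Face `generate()` arguments; the helper name is illustrative only.

```python
# Illustrative sketch only: mapping "Temperature = 0" to greedy search with
# Hugging Face generate() kwargs. The helper name is an assumption, not code
# from backend.py.
def decoding_kwargs(temperature: float) -> dict:
    if temperature > 0:
        # Positive temperature: sample from the (temperature-scaled) distribution.
        return {"do_sample": True, "temperature": temperature}
    # Temperature of 0: disable sampling so generate() picks the argmax token
    # at every step, i.e. greedy search.
    return {"do_sample": False}
```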
--- a/backend/python/transformers/backend.py
+++ b/backend/python/transformers/backend.py
@@ -22,9 +22,9 @@ import torch.cuda
 XPU=os.environ.get("XPU", "0") == "1"
 if XPU:
-    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 else:
-    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -246,28 +246,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
         sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
-        # print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
-        # print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
 
     async def _predict(self, request, context, streaming=False):
         set_seed(request.Seed)
-        if request.TopP == 0:
-            request.TopP = 0.9
+        if request.TopP < 0 or request.TopP > 1:
+            request.TopP = 1
 
-        if request.TopK == 0:
-            request.TopK = 40
+        if request.TopK <= 0:
+            request.TopK = 50
 
+        if request.Temperature > 0 :
+            sample=True
+        else:
+            sample=False
+            request.TopP == None
+            request.TopK == None
+            request.Temperature == None
 
         prompt = request.Prompt
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
             prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
 
-        eos_token_id = self.tokenizer.eos_token_id
-        if request.StopPrompts:
-            eos_token_id = []
-            for word in request.StopPrompts:
-                eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))
 
         inputs = self.tokenizer(prompt, return_tensors="pt")
 
         if request.Tokens > 0:
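The hunk above also normalizes the sampling parameters before generation. A self-contained sketch of equivalent logic, using a plain dataclass in place of the gRPC request object (names and defaults here are assumptions, not from the backend):

```python
# Sketch of the parameter normalization in the hunk above. SamplingParams is a
# stand-in for the gRPC request object; it is not part of the backend.
from dataclasses import dataclass

@dataclass
class SamplingParams:
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = 50

def normalize(p: SamplingParams) -> tuple[SamplingParams, bool]:
    """Clamp out-of-range values and decide whether to sample."""
    if p.top_p < 0 or p.top_p > 1:
        p.top_p = 1.0            # top_p must lie in (0, 1]; 1 disables nucleus filtering
    if p.top_k <= 0:
        p.top_k = 50             # transformers' default top_k
    sample = p.temperature > 0   # Temperature of 0 selects greedy search
    return p, sample
```

Keeping a single boolean `sample` flag lets one `generate()` call serve both paths, with `do_sample` switching between sampling and greedy decoding.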
@@ -281,6 +281,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             inputs = inputs.to("xpu")
             streaming = False
 
+        criteria=[]
+        if request.StopPrompts:
+            criteria = StoppingCriteriaList(
+                [
+                    StopStringCriteria(tokenizer=self.tokenizer, stop_strings=request.StopPrompts),
+                ]
+            )
+
         if streaming:
             streamer=TextIteratorStreamer(self.tokenizer,
                                         skip_prompt=True,
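The stop-word handling above relies on `StopStringCriteria`, which is available in recent transformers releases. A standalone usage sketch follows; the model name and stop strings are placeholders, not taken from any LocalAI configuration.

```python
# Standalone sketch of stop-word handling with StoppingCriteriaList and
# StopStringCriteria. The model name and stop strings are placeholders.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteriaList,
    StopStringCriteria,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

stop_words = ["\nUser:", "###"]
criteria = StoppingCriteriaList(
    [StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_words)]
)

inputs = tokenizer("Assistant: Hello! How can I help you today?", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    stopping_criteria=criteria,           # generation halts once a stop string appears
    pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Unlike the previous approach of converting stop words into token ids for `eos_token_id`, string-level criteria also match stop words that span several tokens or do not correspond to a single vocabulary entry.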
@@ -290,11 +298,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         temperature=request.Temperature,
                         top_p=request.TopP,
                         top_k=request.TopK,
-                        do_sample=True,
+                        do_sample=sample,
                         attention_mask=inputs["attention_mask"],
-                        eos_token_id=eos_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id,
                         pad_token_id=self.tokenizer.eos_token_id,
-                        streamer=streamer)
+                        streamer=streamer,
+                        stopping_criteria=criteria,
+                        use_cache=True,
+                        )
             thread=Thread(target=self.model.generate, kwargs=config)
             thread.start()
             generated_text = ""
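For reference, the streaming pattern used in this branch as a self-contained sketch (the model name is a placeholder): `generate()` runs on a background thread while `TextIteratorStreamer` yields decoded text as it is produced.

```python
# Minimal sketch of the streaming pattern above: run generate() in a thread and
# consume decoded chunks from TextIteratorStreamer. The model is a placeholder.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Once upon a time", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

config = dict(
    **inputs,
    max_new_tokens=32,
    use_cache=True,      # reuse the KV cache between decoding steps
    pad_token_id=tokenizer.eos_token_id,
    streamer=streamer,
)
Thread(target=model.generate, kwargs=config).start()

generated_text = ""
for chunk in streamer:   # iteration blocks until the next chunk is ready
    generated_text += chunk
print(generated_text)
```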
@@ -311,18 +322,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                                              temperature=request.Temperature,
                                              top_p=request.TopP,
                                              top_k=request.TopK,
-                                             do_sample=True,
+                                             do_sample=sample,
                                              pad_token=self.tokenizer.eos_token_id)
             else:
-                outputs = self.model.generate(inputs["input_ids"],
+                outputs = self.model.generate(**inputs,
                                               max_new_tokens=max_tokens,
                                               temperature=request.Temperature,
                                               top_p=request.TopP,
                                               top_k=request.TopK,
-                                              do_sample=True,
-                                              attention_mask=inputs["attention_mask"],
-                                              eos_token_id=eos_token_id,
-                                              pad_token_id=self.tokenizer.eos_token_id)
+                                              do_sample=sample,
+                                              eos_token_id=self.tokenizer.eos_token_id,
+                                              pad_token_id=self.tokenizer.eos_token_id,
+                                              stopping_criteria=criteria,
+                                              use_cache=True,
+                                              )
             generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
 
         if streaming:
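Two details of the non-streaming path are worth spelling out: `use_cache=True` keeps the key/value cache so each decoding step only computes attention for the newly generated token, and slicing the output past the prompt length decodes only the new tokens. A hedged sketch (the model name is a placeholder):

```python
# Sketch of the non-streaming path: enable the KV cache and decode only the
# tokens generated after the prompt. The model name is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=16,
    do_sample=False,                      # greedy search, as when Temperature == 0
    use_cache=True,                       # reuse cached key/value tensors per step
    pad_token_id=tokenizer.eos_token_id,
)

# generate() returns prompt + completion; keep only the completion.
new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0])
```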