Repository: https://github.com/mudler/LocalAI.git

commit 981310c94f (parent ee7904f170)

    merge sentencetransformers

    Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

6 changed files with 34 additions and 14 deletions
@@ -25,6 +25,8 @@ from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer
 from transformers import AutoProcessor, MusicgenForConditionalGeneration
 from scipy.io import wavfile
 import outetts
+from sentence_transformers import SentenceTransformer


 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -88,6 +90,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         self.CUDA = torch.cuda.is_available()
         self.OV=False
         self.OuteTTS=False
+        self.SentenceTransformer = False

         device_map="cpu"
@@ -235,6 +238,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
             else:
                 self.speaker = self.interface.load_default_speaker(name=SPEAKER)
+        elif request.Type == "SentenceTransformer":
+            self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
+            self.SentenceTransformer = True
         else:
             print("Automodel", file=sys.stderr)
             self.model = AutoModel.from_pretrained(model_name,
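The new elif branch above hands model loading to the sentence-transformers library and records that choice in the self.SentenceTransformer flag so the embedding path can branch on it later. As a rough standalone sketch of what that library call does (the model name and flag value here are illustrative, not taken from this commit):

# Minimal sketch of the sentence-transformers API used by the new LoadModel branch.
# "sentence-transformers/all-MiniLM-L6-v2" is only an example model name.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    trust_remote_code=False,   # the backend forwards request.TrustRemoteCode here
)

# encode() returns a numpy array: one vector for a single string input
vector = model.encode("LocalAI computes embeddings locally")
print(vector.shape)            # (384,) for this particular model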
@@ -286,6 +292,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         max_length = 512
         if request.Tokens != 0:
             max_length = request.Tokens
+
+        embeds = None
+
+        if self.SentenceTransformer:
+            print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
+            embeds = self.model.encode(request.Embeddings)
+        else:
             encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

             # Create word embeddings
@@ -297,7 +310,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

             # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
             sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
-        return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
+            embeds = sentence_embeddings[0]
+        return backend_pb2.EmbeddingResult(embeddings=embeds)

     async def _predict(self, request, context, streaming=False):
         set_seed(request.Seed)
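Taken together, the two hunks above split the Embedding() method into two paths: when the model was loaded as a SentenceTransformer, model.encode() produces the vector directly; otherwise the original Hugging Face path (tokenize, forward pass, mean pooling) fills embeds, and a single return statement now covers both. The mean_pooling helper is defined elsewhere in the backend and is not part of this diff, so the sketch below uses a conventional implementation and an illustrative model name; treat it as an approximation of the fallback path, not the backend's exact code.

# Sketch of the fallback (else:) embedding path, with a standard mean-pooling helper.
import torch
from transformers import AutoModel, AutoTokenizer

def mean_pooling(model_output, attention_mask):
    # Average token embeddings, ignoring padded positions.
    token_embeddings = model_output[0]                       # (batch, seq_len, hidden)
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

model_name = "sentence-transformers/all-MiniLM-L6-v2"        # illustrative model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

encoded_input = tokenizer("LocalAI computes embeddings locally",
                          padding=True, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    model_output = model(**encoded_input)

embeds = mean_pooling(model_output, encoded_input["attention_mask"])[0]
print(embeds.shape)                                          # torch.Size([384]) for this model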
@@ -4,3 +4,4 @@ accelerate
 transformers
 bitsandbytes
 outetts
+sentence-transformers==3.3.1
@@ -5,3 +5,4 @@ accelerate
 transformers
 bitsandbytes
 outetts
+sentence-transformers==3.3.1
@@ -4,3 +4,4 @@ llvmlite==0.43.0
 transformers
 bitsandbytes
 outetts
+sentence-transformers==3.3.1
@@ -5,3 +5,5 @@ transformers
 llvmlite==0.43.0
 bitsandbytes
 outetts
+bitsandbytes
+sentence-transformers==3.3.1
@@ -7,3 +7,4 @@ llvmlite==0.43.0
 intel-extension-for-transformers
 bitsandbytes
 outetts
+sentence-transformers==3.3.1
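The five requirements hunks above pin sentence-transformers==3.3.1 across the backend's build variants. Once a model of type "SentenceTransformer" is configured, the new path can be exercised through LocalAI's OpenAI-compatible embeddings endpoint. The snippet below is a hypothetical smoke test: the host, port, and model name are assumptions for illustration and do not come from this commit.

# Hypothetical client-side check against a running LocalAI instance.
import json
import urllib.request

payload = json.dumps({
    "model": "my-sentence-transformer",      # whatever name the LocalAI model config uses
    "input": "LocalAI computes embeddings locally",
}).encode("utf-8")

req = urllib.request.Request(
    "http://localhost:8080/v1/embeddings",   # assumed default address; adjust as needed
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

print(len(body["data"][0]["embedding"]))     # dimensionality of the returned vector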