mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-29 22:20:43 +00:00
merge sentencetransformers
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
ee7904f170
commit
981310c94f
6 changed files with 34 additions and 14 deletions
|
@ -25,6 +25,8 @@ from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreame
|
||||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
import outetts
|
import outetts
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
|
@ -88,6 +90,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
self.CUDA = torch.cuda.is_available()
|
self.CUDA = torch.cuda.is_available()
|
||||||
self.OV=False
|
self.OV=False
|
||||||
self.OuteTTS=False
|
self.OuteTTS=False
|
||||||
|
self.SentenceTransformer = False
|
||||||
|
|
||||||
device_map="cpu"
|
device_map="cpu"
|
||||||
|
|
||||||
|
@ -235,6 +238,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
|
self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
|
||||||
else:
|
else:
|
||||||
self.speaker = self.interface.load_default_speaker(name=SPEAKER)
|
self.speaker = self.interface.load_default_speaker(name=SPEAKER)
|
||||||
|
elif request.Type == "SentenceTransformer":
|
||||||
|
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||||
|
self.SentenceTransformer = True
|
||||||
else:
|
else:
|
||||||
print("Automodel", file=sys.stderr)
|
print("Automodel", file=sys.stderr)
|
||||||
self.model = AutoModel.from_pretrained(model_name,
|
self.model = AutoModel.from_pretrained(model_name,
|
||||||
|
@ -286,6 +292,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
max_length = 512
|
max_length = 512
|
||||||
if request.Tokens != 0:
|
if request.Tokens != 0:
|
||||||
max_length = request.Tokens
|
max_length = request.Tokens
|
||||||
|
|
||||||
|
embeds = None
|
||||||
|
|
||||||
|
if self.SentenceTransformer:
|
||||||
|
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
||||||
|
embeds = self.model.encode(request.Embeddings)
|
||||||
|
else:
|
||||||
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
|
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
|
||||||
|
|
||||||
# Create word embeddings
|
# Create word embeddings
|
||||||
|
@ -297,7 +310,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
|
||||||
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
|
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
|
||||||
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
||||||
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
|
embeds = sentence_embeddings[0]
|
||||||
|
return backend_pb2.EmbeddingResult(embeddings=embeds)
|
||||||
|
|
||||||
async def _predict(self, request, context, streaming=False):
|
async def _predict(self, request, context, streaming=False):
|
||||||
set_seed(request.Seed)
|
set_seed(request.Seed)
|
||||||
|
|
|
@ -4,3 +4,4 @@ accelerate
|
||||||
transformers
|
transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
|
sentence-transformers==3.3.1
|
||||||
|
|
|
@ -5,3 +5,4 @@ accelerate
|
||||||
transformers
|
transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
|
sentence-transformers==3.3.1
|
||||||
|
|
|
@ -4,3 +4,4 @@ llvmlite==0.43.0
|
||||||
transformers
|
transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
|
sentence-transformers==3.3.1
|
||||||
|
|
|
@ -5,3 +5,5 @@ transformers
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
|
bitsandbytes
|
||||||
|
sentence-transformers==3.3.1
|
||||||
|
|
|
@ -7,3 +7,4 @@ llvmlite==0.43.0
|
||||||
intel-extension-for-transformers
|
intel-extension-for-transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
outetts
|
outetts
|
||||||
|
sentence-transformers==3.3.1
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue