feat: cuda transformers (#1401)

* Use cuda in transformers if available

tensorflow probably needs a different check.

Signed-off-by: Erich Schubert <kno10@users.noreply.github.com>

* feat: expose CUDA at top level

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests: add to tests and create workflow for py extra backends

* doc: update note on how to use core images

---------

Signed-off-by: Erich Schubert <kno10@users.noreply.github.com>
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Erich Schubert <kno10@users.noreply.github.com>
This commit is contained in:
Ettore Di Giacinto 2023-12-08 15:45:04 +01:00 committed by GitHub
parent 3822bd2369
commit 887b3dff04
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 163 additions and 11 deletions

View file

@ -31,7 +31,7 @@ class TestBackendServicer(unittest.TestCase):
"""
This method tests if the server starts up successfully
"""
time.sleep(2)
time.sleep(10)
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
@ -48,11 +48,12 @@ class TestBackendServicer(unittest.TestCase):
"""
This method tests if the model is loaded successfully
"""
time.sleep(10)
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens"))
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
self.assertTrue(response.success)
self.assertEqual(response.message, "Model loaded successfully")
except Exception as err:
@ -65,11 +66,13 @@ class TestBackendServicer(unittest.TestCase):
"""
This method tests if the embeddings are generated successfully
"""
time.sleep(10)
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens"))
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
print(response.message)
self.assertTrue(response.success)
embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
embedding_response = stub.Embedding(embedding_request)

View file

@ -14,14 +14,27 @@ import backend_pb2
import backend_pb2_grpc
import grpc
import torch
from transformers import AutoModel
from transformers import AutoTokenizer, AutoModel
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
def mean_pooling(model_output, attention_mask):
"""
Mean pooling to get sentence embeddings. See:
https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1
"""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) # Sum columns
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sum_embeddings / sum_mask
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
@ -56,9 +69,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
model_name = request.Model
try:
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True) # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if request.CUDA:
try:
# TODO: also tensorflow, make configurable
import torch.cuda
if torch.cuda.is_available():
print("Loading model", model_name, "to CUDA.", file=sys.stderr)
self.model = self.model.to("cuda")
except Exception as err:
print("Not using CUDA:", err, file=sys.stderr)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service
# Replace this with your desired response
return backend_pb2.Result(message="Model loaded successfully", success=True)
@ -74,10 +97,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
An EmbeddingResult object that contains the calculated embeddings.
"""
# Implement your logic here for the Embedding service
# Replace this with your desired response
# Tokenize input
max_length = 512
if request.Tokens != 0:
max_length = request.Tokens
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
# Create word embeddings
model_output = self.model(**encoded_input)
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).detach().numpy()
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
sentence_embeddings = self.model.encode(request.Embeddings)
print("Embeddings:", sentence_embeddings, file=sys.stderr)
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)