diff --git a/backend/python/transformers/test.py b/backend/python/transformers/test.py index 305b0a93..14efa6a7 100644 --- a/backend/python/transformers/test.py +++ b/backend/python/transformers/test.py @@ -133,5 +133,41 @@ class TestBackendServicer(unittest.TestCase): except Exception as err: print(err) self.fail("SoundGeneration service failed") + finally: + self.tearDown() + + def test_embed_load_model(self): + """ + This method tests if the model is loaded successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens",Type="SentenceTransformer")) + self.assertTrue(response.success) + self.assertEqual(response.message, "Model loaded successfully") + except Exception as err: + print(err) + self.fail("LoadModel service failed") + finally: + self.tearDown() + + def test_sentencetransformers_embedding(self): + """ + This method tests if the embeddings are generated successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens",Type="SentenceTransformer")) + self.assertTrue(response.success) + embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.") + embedding_response = stub.Embedding(embedding_request) + self.assertIsNotNone(embedding_response.embeddings) + except Exception as err: + print(err) + self.fail("Embedding service failed") finally: self.tearDown() \ No newline at end of file