diff --git a/backend/python/petals/test_petals.py b/backend/python/petals/test_petals.py index a0a800e1..4c156ca6 100644 --- a/backend/python/petals/test_petals.py +++ b/backend/python/petals/test_petals.py @@ -47,7 +47,8 @@ class TestBackendServicer(unittest.TestCase): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="petals-team/StableBeluga")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="bigscience/bloom-560m")) + print(response) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: @@ -64,9 +65,9 @@ class TestBackendServicer(unittest.TestCase): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="petals-team/StableBeluga")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="bigscience/bloom-560m")) self.assertTrue(response.success) - req = backend_pb2.PredictOptions(prompt="The capital of France is") + req = backend_pb2.PredictOptions(Prompt="The capital of France is") resp = stub.Predict(req) self.assertIsNotNone(resp.message) except Exception as err: diff --git a/backend/python/vllm/test_backend_vllm.py b/backend/python/vllm/test_backend_vllm.py index 06317c73..7760f816 100644 --- a/backend/python/vllm/test_backend_vllm.py +++ b/backend/python/vllm/test_backend_vllm.py @@ -66,7 +66,7 @@ class TestBackendServicer(unittest.TestCase): stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) self.assertTrue(response.success) - req = backend_pb2.PredictOptions(prompt="The capital of France is") + req = backend_pb2.PredictOptions(Prompt="The capital of France is") resp = stub.Predict(req) self.assertIsNotNone(resp.message) except Exception as err: