diff --git a/tests/test_sendchat.py b/tests/test_sendchat.py index 7bb8fcfab..460525155 100644 --- a/tests/test_sendchat.py +++ b/tests/test_sendchat.py @@ -12,12 +12,11 @@ class PrintCalled(Exception): class TestSendChat(unittest.TestCase): + @patch("litellm.completion") @patch("builtins.print") - def test_send_with_retries_rate_limit_error(self, mock_print): - mock_client = MagicMock() - + def test_send_with_retries_rate_limit_error(self, mock_print, mock_completion): # Set up the mock to raise - mock_client.chat.completions.create.side_effect = [ + mock_completion.side_effect = [ openai.RateLimitError( "rate limit exceeded", response=MagicMock(), @@ -27,20 +26,18 @@ class TestSendChat(unittest.TestCase): ] # Call the send_with_retries method - send_with_retries(mock_client, "model", ["message"], None, False) + send_with_retries("model", ["message"], None, False) mock_print.assert_called_once() - @patch("aider.sendchat.openai.ChatCompletion.create") + @patch("litellm.completion") @patch("builtins.print") - def test_send_with_retries_connection_error(self, mock_print, mock_chat_completion_create): - mock_client = MagicMock() - + def test_send_with_retries_connection_error(self, mock_print, mock_completion): # Set up the mock to raise - mock_client.chat.completions.create.side_effect = [ + mock_completion.side_effect = [ httpx.ConnectError("Connection error"), None, ] # Call the send_with_retries method - send_with_retries(mock_client, "model", ["message"], None, False) + send_with_retries("model", ["message"], None, False) mock_print.assert_called_once()