From 092e7f6b3c761b25b79ca9b26c26ef6aecd7cdad Mon Sep 17 00:00:00 2001 From: "Paul Gauthier (aider)" Date: Fri, 8 Nov 2024 10:01:11 -0800 Subject: [PATCH] test: Add comprehensive tests for sendchat module functionality --- tests/basic/test_sendchat.py | 69 +++++++++++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/tests/basic/test_sendchat.py b/tests/basic/test_sendchat.py index 28397e5c8..58758905f 100644 --- a/tests/basic/test_sendchat.py +++ b/tests/basic/test_sendchat.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch from aider.exceptions import LiteLLMExceptions from aider.llm import litellm -from aider.sendchat import simple_send_with_retries +from aider.sendchat import simple_send_with_retries, send_completion, CACHE class PrintCalled(Exception): @@ -11,6 +11,9 @@ class PrintCalled(Exception): class TestSendChat(unittest.TestCase): + def setUp(self): + self.mock_messages = [{"role": "user", "content": "Hello"}] + self.mock_model = "gpt-4" def test_litellm_exceptions(self): litellm_ex = LiteLLMExceptions() litellm_ex._load(strict=True) @@ -35,3 +38,67 @@ class TestSendChat(unittest.TestCase): # Call the simple_send_with_retries method simple_send_with_retries("model", ["message"]) assert mock_print.call_count == 3 + + @patch("litellm.completion") + def test_send_completion_basic(self, mock_completion): + # Setup mock response + mock_response = MagicMock() + mock_completion.return_value = mock_response + + # Test basic send_completion + hash_obj, response = send_completion( + self.mock_model, + self.mock_messages, + functions=None, + stream=False + ) + + assert response == mock_response + mock_completion.assert_called_once() + + @patch("litellm.completion") + def test_send_completion_with_functions(self, mock_completion): + mock_function = { + "name": "test_function", + "parameters": {"type": "object"} + } + + hash_obj, response = send_completion( + self.mock_model, + self.mock_messages, + functions=[mock_function], + stream=False + ) + + # Verify function was properly included in tools + called_kwargs = mock_completion.call_args.kwargs + assert "tools" in called_kwargs + assert called_kwargs["tools"][0]["function"] == mock_function + + @patch("litellm.completion") + def test_simple_send_attribute_error(self, mock_completion): + # Setup mock to raise AttributeError + mock_completion.return_value = MagicMock() + mock_completion.return_value.choices = None + + # Should return None on AttributeError + result = simple_send_with_retries(self.mock_model, self.mock_messages) + assert result is None + + @patch("litellm.completion") + @patch("builtins.print") + def test_simple_send_non_retryable_error(self, mock_print, mock_completion): + # Test with an error that shouldn't trigger retries + mock = MagicMock() + mock.status_code = 400 + + mock_completion.side_effect = litellm.InvalidRequestError( + "Invalid request", + response=mock, + llm_provider="test_provider", + model="test_model" + ) + + result = simple_send_with_retries(self.mock_model, self.mock_messages) + assert result is None + assert mock_print.call_count == 2 # Error message and description