diff --git a/tests/test_coder.py b/tests/test_coder.py index e15870b84..fbfa8417d 100644 --- a/tests/test_coder.py +++ b/tests/test_coder.py @@ -53,7 +53,7 @@ class TestCoder(unittest.TestCase): # Mock the git repo mock = MagicMock() mock.return_value = set(["file1.txt", "file2.py"]) - coder.get_tracked_files = mock + coder.repo.get_tracked_files = mock # Call the check_for_file_mentions method coder.check_for_file_mentions("Please check file1.txt and file2.py") @@ -97,7 +97,7 @@ class TestCoder(unittest.TestCase): mock = MagicMock() mock.return_value = set(["file1.txt", "file2.py"]) - coder.get_tracked_files = mock + coder.repo.get_tracked_files = mock # Call the check_for_file_mentions method coder.check_for_file_mentions("Please check file1.txt and file2.py") @@ -128,7 +128,7 @@ class TestCoder(unittest.TestCase): mock = MagicMock() mock.return_value = set([str(fname), str(other_fname)]) - coder.get_tracked_files = mock + coder.repo.get_tracked_files = mock # Call the check_for_file_mentions method coder.check_for_file_mentions(f"Please check {fname}!") @@ -146,7 +146,7 @@ class TestCoder(unittest.TestCase): mock = MagicMock() mock.return_value = set([str(fname)]) - coder.get_tracked_files = mock + coder.repo.get_tracked_files = mock dump(fname) # Call the check_for_file_mentions method @@ -154,26 +154,6 @@ class TestCoder(unittest.TestCase): self.assertEqual(coder.abs_fnames, set([str(fname.resolve())])) - def test_get_commit_message_no_strip_unmatched_quotes(self): - # Mock the IO object - mock_io = MagicMock() - - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(models.GPT4, None, mock_io) - - # Mock the send method to set partial_response_content and return False - def mock_send(*args, **kwargs): - coder.partial_response_content = 'a good "commit message"' - return False - - coder.send = MagicMock(side_effect=mock_send) - - # Call the get_commit_message method with dummy diff and context - result = coder.get_commit_message("dummy diff", "dummy context") - - # Assert that the returned message is the expected one - self.assertEqual(result, 'a good "commit message"') - def test_run_with_file_deletion(self): # Create a few temporary files diff --git a/tests/test_repo.py b/tests/test_repo.py index f711527b7..de59a4fe4 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -51,6 +51,25 @@ class TestRepo(unittest.TestCase): # Assert that the returned message is the expected one self.assertEqual(result, "a good commit message") + @patch("aider.repo.send_with_retries") + def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send): + # Set the return value of the mocked function + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = 'a good "commit message"' + mock_send.return_value = ( + None, + mock_response + ) + + repo = AiderRepo(InputOutput(), None) + # Call the get_commit_message method with dummy diff and context + result = repo.get_commit_message("dummy diff", "dummy context") + + # Assert that the returned message is the expected one + self.assertEqual(result, 'a good "commit message"') + + def test_get_tracked_files(self): # Create a temporary directory tempdir = Path(tempfile.mkdtemp())