diff --git a/tests/test_wholefile.py b/tests/test_wholefile.py index 0237515cf..7f2dfe385 100644 --- a/tests/test_wholefile.py +++ b/tests/test_wholefile.py @@ -3,8 +3,10 @@ import shutil import tempfile import unittest from pathlib import Path +from unittest.mock import MagicMock, patch from aider import models +from aider.coders import Coder from aider.coders.wholefile_coder import WholeFileCoder from aider.io import InputOutput @@ -15,10 +17,16 @@ class TestWholeFileCoder(unittest.TestCase): self.tempdir = tempfile.mkdtemp() os.chdir(self.tempdir) + self.patcher = patch("aider.coders.base_coder.check_model_availability") + self.mock_check = self.patcher.start() + self.mock_check.return_value = True + def tearDown(self): os.chdir(self.original_cwd) shutil.rmtree(self.tempdir, ignore_errors=True) + self.patcher.stop() + def test_update_files(self): # Create a sample file in the temporary directory sample_file = "sample.txt" @@ -198,6 +206,45 @@ after b self.assertEqual(fname_a.read_text(), "after a\n") self.assertEqual(fname_b.read_text(), "after b\n") + def test_full_edit(self): + # Create a few temporary files + _, file1 = tempfile.mkstemp() + + with open(file1, "w", encoding="utf-8") as f: + f.write("one\ntwo\nthree\n") + + files = [file1] + + # Initialize the Coder object with the mocked IO and mocked repo + coder = Coder.create( + models.GPT4, "whole", io=InputOutput(), openai_api_key="fake_key", fnames=files + ) + + # no trailing newline so the response content below doesn't add ANOTHER newline + new_content = "new\ntwo\nthree" + + def mock_send(*args, **kwargs): + coder.partial_response_content = f""" +Do this: + +{Path(file1).name} +``` +{new_content} +``` + +""" + coder.partial_response_function_call = dict() + + coder.send = MagicMock(side_effect=mock_send) + + # Call the run method with a message + coder.run(with_message="hi") + + content = open(file1, encoding="utf-8").read() + + # check for one trailing newline + self.assertEqual(content, new_content + "\n") + if __name__ == "__main__": unittest.main()