full edit for wholefile

This commit is contained in:
Paul Gauthier 2023-07-06 11:43:20 -07:00
parent 7d3c40ea21
commit 089fa57ede

View file

@ -3,8 +3,10 @@ import shutil
import tempfile import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, patch
from aider import models from aider import models
from aider.coders import Coder
from aider.coders.wholefile_coder import WholeFileCoder from aider.coders.wholefile_coder import WholeFileCoder
from aider.io import InputOutput from aider.io import InputOutput
@ -15,10 +17,16 @@ class TestWholeFileCoder(unittest.TestCase):
self.tempdir = tempfile.mkdtemp() self.tempdir = tempfile.mkdtemp()
os.chdir(self.tempdir) os.chdir(self.tempdir)
self.patcher = patch("aider.coders.base_coder.check_model_availability")
self.mock_check = self.patcher.start()
self.mock_check.return_value = True
def tearDown(self): def tearDown(self):
os.chdir(self.original_cwd) os.chdir(self.original_cwd)
shutil.rmtree(self.tempdir, ignore_errors=True) shutil.rmtree(self.tempdir, ignore_errors=True)
self.patcher.stop()
def test_update_files(self): def test_update_files(self):
# Create a sample file in the temporary directory # Create a sample file in the temporary directory
sample_file = "sample.txt" sample_file = "sample.txt"
@ -198,6 +206,45 @@ after b
self.assertEqual(fname_a.read_text(), "after a\n") self.assertEqual(fname_a.read_text(), "after a\n")
self.assertEqual(fname_b.read_text(), "after b\n") self.assertEqual(fname_b.read_text(), "after b\n")
def test_full_edit(self):
# Create a few temporary files
_, file1 = tempfile.mkstemp()
with open(file1, "w", encoding="utf-8") as f:
f.write("one\ntwo\nthree\n")
files = [file1]
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder.create(
models.GPT4, "whole", io=InputOutput(), openai_api_key="fake_key", fnames=files
)
# no trailing newline so the response content below doesn't add ANOTHER newline
new_content = "new\ntwo\nthree"
def mock_send(*args, **kwargs):
coder.partial_response_content = f"""
Do this:
{Path(file1).name}
```
{new_content}
```
"""
coder.partial_response_function_call = dict()
coder.send = MagicMock(side_effect=mock_send)
# Call the run method with a message
coder.run(with_message="hi")
content = open(file1, encoding="utf-8").read()
# check for one trailing newline
self.assertEqual(content, new_content + "\n")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()