diff --git a/aider/coders/wholefile_coder.py b/aider/coders/wholefile_coder.py index 4c862f2e4..bcba5060b 100644 --- a/aider/coders/wholefile_coder.py +++ b/aider/coders/wholefile_coder.py @@ -35,6 +35,7 @@ class WholeFileCoder(Coder): output = [] lines = content.splitlines(keepends=True) + fname = None new_lines = [] for i, line in enumerate(lines): @@ -65,16 +66,16 @@ class WholeFileCoder(Coder): # starting a new block if i == 0: - raise ValueError("No filename provided before ``` block") - - fname = lines[i - 1].strip() - if fname not in chat_files: if len(chat_files) == 1: - fname = list(chat_files)[0] + fname = chat_files[0] else: - show_chat_files = " ".join(chat_files) - # TODO: adopt the new allowed_to_edit() - raise ValueError(f"{fname} is not one of: {show_chat_files}") + # TODO: sense which file it is by diff size + raise ValueError("No filename provided before ``` block") + else: + fname = lines[i - 1].strip() + + if mode == "update" and not self.allowed_to_edit(fname): + raise ValueError(f"{fname} is not one of: {show_chat_files}") elif fname: new_lines.append(line) diff --git a/tests/test_wholefile.py b/tests/test_wholefile.py index e69de29bb..2c71407fd 100644 --- a/tests/test_wholefile.py +++ b/tests/test_wholefile.py @@ -0,0 +1,36 @@ +import os +from pathlib import Path +import tempfile +import unittest + +from aider.coders.wholefile_coder import WholeFileCoder +from aider.io import InputOutput + +class TestWholeFileCoder(unittest.TestCase): + def test_update_files(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a sample file in the temporary directory + sample_file = os.path.join(temp_dir, "sample.txt") + with open(sample_file, "w") as f: + f.write("Original content\n") + + # Initialize WholeFileCoder with the temporary directory + io = InputOutput() + coder = WholeFileCoder(root=temp_dir, io=io) + + # Set the partial response content with the updated content + coder.partial_response_content = f"{sample_file}\n```\nUpdated content\n```" + + # Call update_files method + edited_files = coder.update_files() + + # Check if the sample file was updated + self.assertIn("sample.txt", edited_files) + + # Check if the content of the sample file was updated + with open(sample_file, "r") as f: + updated_content = f.read() + self.assertEqual(updated_content, "Updated content\n") + +if __name__ == "__main__": + unittest.main()