From f2a03e917d3af31487d28948900053f149d2bbdc Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Sat, 24 Jun 2023 15:13:39 -0700 Subject: [PATCH] Handle missing filename when only 1 in session --- aider/coders/wholefile_coder.py | 29 +++++++++++++++----------- tests/test_wholefile.py | 36 +++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/aider/coders/wholefile_coder.py b/aider/coders/wholefile_coder.py index 75df481cd..6ffabbd2d 100644 --- a/aider/coders/wholefile_coder.py +++ b/aider/coders/wholefile_coder.py @@ -2,6 +2,7 @@ from pathlib import Path from aider import diffs +from ..dump import dump # noqa: F401 from .base_coder import Coder from .wholefile_prompts import WholeFilePrompts @@ -34,8 +35,9 @@ class WholeFileCoder(Coder): fname = None new_lines = [] for i, line in enumerate(lines): + dump(repr(fname), repr(line), repr(new_lines)) if line.startswith("```"): - if fname: + if fname is not None: # ending an existing block full_path = (Path(self.root) / fname).absolute() @@ -59,22 +61,23 @@ class WholeFileCoder(Coder): new_lines = [] continue - # starting a new block - if i == 0: + # fname==None ... starting a new block + if i > 0: + fname = lines[i - 1].strip() + if not fname: # blank line? or ``` was on first line i==0 if len(chat_files) == 1: fname = chat_files[0] else: # TODO: sense which file it is by diff size raise ValueError("No filename provided before ``` block") - else: - fname = lines[i - 1].strip() - elif fname: + + elif fname is not None: new_lines.append(line) else: output.append(line) if mode == "diff": - if fname: + if fname is not None: # ending an existing block full_path = (Path(self.root) / fname).absolute() @@ -89,10 +92,12 @@ class WholeFileCoder(Coder): return "\n".join(output) - if fname and self.allowed_to_edit(fname): - edited.add(fname) - if not self.dry_run: - new_lines = "".join(new_lines) - Path(full_path).write_text(new_lines) + if fname: + full_path = self.allowed_to_edit(fname) + if full_path: + edited.add(fname) + if not self.dry_run: + new_lines = "".join(new_lines) + Path(full_path).write_text(new_lines) return edited diff --git a/tests/test_wholefile.py b/tests/test_wholefile.py index 9f77d33ee..b0c77bd98 100644 --- a/tests/test_wholefile.py +++ b/tests/test_wholefile.py @@ -68,6 +68,42 @@ class TestWholeFileCoder(unittest.TestCase): updated_content = f.read() self.assertEqual(updated_content, "Updated content\n") + def test_update_files_earlier_filename(self): + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + + sample_file = "accumulate.py" + content = ( + "def accumulate(collection, operation):\n return [operation(x) for x in" + " collection]\n" + ) + + with open(sample_file, "w") as f: + f.write("Original content\n") + + # Initialize WholeFileCoder with the temporary directory + io = InputOutput(yes=True) + coder = WholeFileCoder(main_model=models.GPT35, io=io, fnames=[sample_file]) + + # Set the partial response content with the updated content + coder.partial_response_content = ( + f"Here's the modified `{sample_file}` file that implements the `accumulate`" + f" function as per the given instructions:\n\n```\n{content}```\n\nThis" + " implementation uses a list comprehension to apply the `operation` function to" + " each element of the `collection` and returns the resulting list." + ) + + # Call update_files method + edited_files = coder.update_files() + + # Check if the sample file was updated + self.assertIn(sample_file, edited_files) + + # Check if the content of the sample file was updated + with open(sample_file, "r") as f: + updated_content = f.read() + self.assertEqual(updated_content, content) + if __name__ == "__main__": unittest.main()