From 684b0e496400ae643b732629c87921b6a5d20f30 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Thu, 11 May 2023 22:06:02 -0700 Subject: [PATCH] Provide structured errors to GPT if it generates malformed ORIG/UPD blocks --- aider/coder.py | 11 ++++++- aider/utils.py | 72 ++++++++++++++++++++++++++++++++++++++++----- tests/test_utils.py | 31 +++++++++++++++---- 3 files changed, 100 insertions(+), 14 deletions(-) diff --git a/aider/coder.py b/aider/coder.py index 78092a8b2..21e31a42b 100755 --- a/aider/coder.py +++ b/aider/coder.py @@ -248,6 +248,12 @@ class Coder: try: edited = self.update_files(content, inp) + except ValueError as err: + err = err.args[0] + self.console.print("[red]Malformed ORIGINAL/UPDATE blocks, retrying...") + self.console.print("[red]", Text(err)) + return err + except Exception as err: print(err) print() @@ -373,8 +379,11 @@ class Coder: live.stop() def update_files(self, content, inp): + # might raise ValueError for malformed ORIG/UPD blocks + edits = list(utils.find_original_update_blocks(content)) + edited = set() - for path, original, updated in utils.find_original_update_blocks(content): + for path, original, updated in edits: full_path = os.path.abspath(os.path.join(self.root, path)) if full_path not in self.abs_fnames: diff --git a/aider/utils.py b/aider/utils.py index 09e6813aa..04755ebf4 100644 --- a/aider/utils.py +++ b/aider/utils.py @@ -151,15 +151,73 @@ UPDATED = ">>>>>>> UPDATED" separators = "|".join([ORIGINAL, DIVIDER, UPDATED]) -split_re = re.compile(r"^(" + separators + r")\s*\n") +split_re = re.compile(r"^((?:" + separators + r")[ ]*\n)", re.MULTILINE | re.DOTALL) def find_original_update_blocks(content): - for match in pattern.finditer(content): - _, path, _, original, updated = match.groups() - path = path.strip() - yield path, original, updated + pieces = re.split(split_re, content) + + pieces.reverse() + processed = [] + + try: + while pieces: + cur = pieces.pop() + + if cur in (DIVIDER, UPDATED): + processed.append(cur) + raise ValueError(f"Unexpected {cur}") + + if cur.strip() != ORIGINAL: + processed.append(cur) + continue + + processed.append(cur) # original_marker + + filename = processed[-2].splitlines()[-1] + if not len(filename) or "`" in filename: + raise ValueError(f"Bad/missing filename: {filename}") + + original_text = pieces.pop() + processed.append(original_text) + + divider_marker = pieces.pop() + processed.append(divider_marker) + if divider_marker.strip() != DIVIDER: + raise ValueError(f"Expected {DIVIDER}") + + updated_text = pieces.pop() + + updated_marker = pieces.pop() + if updated_marker.strip() != UPDATED: + raise ValueError(f"Expected {UPDATED}") + + yield filename, original_text, updated_text + except ValueError as e: + processed = "".join(processed) + err = e.args[0] + raise ValueError(f"{processed}\n^^^ {err}") + except IndexError: + processed = "".join(processed) + raise ValueError(f"{processed}\n^^^ Incomplete ORIGINAL/UPDATED block.") + except Exception: + processed = "".join(processed) + raise ValueError(f"{processed}\n^^^ Error parsing ORIGINAL/UPDATED block.") -def test_find_original_update_blocks(): - pass +edit = """ +Here's the change: + +```text +foo.txt +<<<<<<< ORIGINAL +Two +======= +Tooooo +>>>>>>> UPDATED +``` + +Hope you like it! +""" +if __name__ == "__main__": + print(list(find_original_update_blocks(edit))) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5c76610c7..96c678ec5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,5 @@ import unittest -from aider.utils import replace_most_similar_chunk, strip_quoted_wrapping +from aider import utils class TestUtils(unittest.TestCase): @@ -9,7 +9,7 @@ class TestUtils(unittest.TestCase): replace = "This is a replaced text." expected_output = "This is a replaced text..\nAnother line of text.\nYet another line.\n" - result = replace_most_similar_chunk(whole, part, replace) + result = utils.replace_most_similar_chunk(whole, part, replace) self.assertEqual(result, expected_output) def test_replace_most_similar_chunk_not_perfect_match(self): @@ -18,7 +18,7 @@ class TestUtils(unittest.TestCase): replace = "This is a replaced text.\nModified line of text." expected_output = "This is a replaced text.\nModified line of text.\nYet another line." - result = replace_most_similar_chunk(whole, part, replace) + result = utils.replace_most_similar_chunk(whole, part, replace) self.assertEqual(result, expected_output) def test_strip_quoted_wrapping(self): @@ -26,21 +26,40 @@ class TestUtils(unittest.TestCase): "filename.ext\n```\nWe just want this content\nNot the filename and triple quotes\n```" ) expected_output = "We just want this content\nNot the filename and triple quotes\n" - result = strip_quoted_wrapping(input_text, "filename.ext") + result = utils.strip_quoted_wrapping(input_text, "filename.ext") self.assertEqual(result, expected_output) def test_strip_quoted_wrapping_no_filename(self): input_text = "```\nWe just want this content\nNot the triple quotes\n```" expected_output = "We just want this content\nNot the triple quotes\n" - result = strip_quoted_wrapping(input_text) + result = utils.strip_quoted_wrapping(input_text) self.assertEqual(result, expected_output) def test_strip_quoted_wrapping_no_wrapping(self): input_text = "We just want this content\nNot the triple quotes\n" expected_output = "We just want this content\nNot the triple quotes\n" - result = strip_quoted_wrapping(input_text) + result = utils.strip_quoted_wrapping(input_text) self.assertEqual(result, expected_output) + def test_find_original_update_blocks(self): + edit = """ +Here's the change: + +```text +foo.txt +<<<<<<< ORIGINAL +Two +======= +Tooooo +>>>>>>> UPDATED +``` + +Hope you like it! +""" + + edits = list(utils.find_original_update_blocks(edit)) + self.assertEqual(edits, [("foo.txt", "Two\n", "Tooooo\n")]) + if __name__ == "__main__": unittest.main()