Paul Gauthier 2023-06-20 17:04:06 -07:00
parent f1350f169f
commit 5e63ce3352
3 changed files with 309 additions and 312 deletions

View file

@@ -37,7 +37,6 @@ class Coder:
        from . import EditBlockCoder, WholeFileCoder

        if edit_format == "diff":
            dump("here")
            return EditBlockCoder(*args, **kwargs)
        elif edit_format == "whole":
            return WholeFileCoder(*args, **kwargs)

View file

@@ -1,8 +1,9 @@
import math
import os
import re
from difflib import SequenceMatcher
from pathlib import Path
from aider import utils
from ..editors import EditBlockPrompts
from .base import Coder
@@ -17,7 +18,7 @@ class EditBlockCoder(Coder):
    def update_files(self, content):
        # might raise ValueError for malformed ORIG/UPD blocks
        edits = list(utils.find_original_update_blocks(content))
        edits = list(find_original_update_blocks(content))

        edited = set()
        for path, original, updated in edits:
@@ -50,7 +51,7 @@ class EditBlockCoder(Coder):
                self.repo.git.add(full_path)

            edited.add(path)
            if utils.do_replace(full_path, original, updated, self.dry_run):
            if do_replace(full_path, original, updated, self.dry_run):
                if self.dry_run:
                    self.io.tool_output(f"Dry run, did not apply edit to {path}")
                else:
@@ -59,3 +60,307 @@ class EditBlockCoder(Coder):
                self.io.tool_error(f"Failed to apply edit to {path}")

        return edited

def try_dotdotdots(whole, part, replace):
    """
    See if the edit block has ... lines.
    If not, return none.
    If yes, try and do a perfect edit with the ... chunks.
    If there's a mismatch or otherwise imperfect edit, raise ValueError.
    If perfect edit succeeds, return the updated whole.
    """

    dots_re = re.compile(r"(^\s*\.\.\.\n)", re.MULTILINE | re.DOTALL)

    part_pieces = re.split(dots_re, part)
    replace_pieces = re.split(dots_re, replace)

    if len(part_pieces) != len(replace_pieces):
        raise ValueError("Unpaired ... in edit block")

    if len(part_pieces) == 1:
        # no dots in this edit block, just return None
        return

    # Compare odd strings in part_pieces and replace_pieces
    all_dots_match = all(part_pieces[i] == replace_pieces[i] for i in range(1, len(part_pieces), 2))

    if not all_dots_match:
        raise ValueError("Unmatched ... in edit block")

    part_pieces = [part_pieces[i] for i in range(0, len(part_pieces), 2)]
    replace_pieces = [replace_pieces[i] for i in range(0, len(replace_pieces), 2)]

    pairs = zip(part_pieces, replace_pieces)
    for part, replace in pairs:
        if not part and not replace:
            continue

        if not part and replace:
            if not whole.endswith("\n"):
                whole += "\n"
            whole += replace
            continue

        if whole.count(part) != 1:
            raise ValueError(
                "No perfect matching chunk in edit block with ... or part appears more than once"
            )

        whole = whole.replace(part, replace, 1)

    return whole
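
# Illustrative sketch, not part of this commit: how try_dotdotdots() pairs the "..."
# elisions in the ORIGINAL and UPDATED blocks. The sample strings below are invented.
#
# >>> whole = "def f():\n    a()\n    b()\n    c()\n"
# >>> part = "    a()\n...\n    c()\n"
# >>> replace = "    a2()\n...\n    c2()\n"
# >>> try_dotdotdots(whole, part, replace)
# 'def f():\n    a2()\n    b()\n    c2()\n'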

def replace_part_with_missing_leading_whitespace(whole, part, replace):
    whole_lines = whole.splitlines()
    part_lines = part.splitlines()
    replace_lines = replace.splitlines()

    # If all lines in the part start with whitespace, then honor it.
    # But GPT often outdents the part and replace blocks completely,
    # thereby discarding the actual leading whitespace in the file.
    if all((len(pline) > 0 and pline[0].isspace()) for pline in part_lines):
        return

    for i in range(len(whole_lines) - len(part_lines) + 1):
        leading_whitespace = ""
        for j, c in enumerate(whole_lines[i]):
            if c == part_lines[0][0]:
                leading_whitespace = whole_lines[i][:j]
                break

        if not leading_whitespace or not all(c.isspace() for c in leading_whitespace):
            continue

        matched = all(
            whole_lines[i + k].startswith(leading_whitespace + part_lines[k])
            for k in range(len(part_lines))
        )

        if matched:
            replace_lines = [
                leading_whitespace + rline if rline else rline for rline in replace_lines
            ]
            whole_lines = whole_lines[:i] + replace_lines + whole_lines[i + len(part_lines) :]
            return "\n".join(whole_lines) + "\n"

    return None
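
# Illustrative sketch, not part of this commit: GPT returned the block fully outdented,
# but the file indents it by four spaces, so the inferred prefix is re-applied to the
# replacement lines. The sample strings below are invented.
#
# >>> whole = "def f():\n    x = 1\n    y = 2\n"
# >>> replace_part_with_missing_leading_whitespace(whole, "x = 1\ny = 2\n", "x = 10\ny = 2\n")
# 'def f():\n    x = 10\n    y = 2\n'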

def replace_most_similar_chunk(whole, part, replace):
    res = replace_part_with_missing_leading_whitespace(whole, part, replace)
    if res:
        return res

    if part in whole:
        return whole.replace(part, replace)

    try:
        res = try_dotdotdots(whole, part, replace)
    except ValueError:
        return

    if res:
        return res

    similarity_thresh = 0.8

    max_similarity = 0
    most_similar_chunk_start = -1
    most_similar_chunk_end = -1

    whole_lines = whole.splitlines()
    part_lines = part.splitlines()

    scale = 0.1
    min_len = math.floor(len(part_lines) * (1 - scale))
    max_len = math.ceil(len(part_lines) * (1 + scale))

    for length in range(min_len, max_len):
        for i in range(len(whole_lines) - length + 1):
            chunk = whole_lines[i : i + length]
            chunk = "\n".join(chunk)

            similarity = SequenceMatcher(None, chunk, part).ratio()

            if similarity > max_similarity and similarity:
                max_similarity = similarity
                most_similar_chunk_start = i
                most_similar_chunk_end = i + length

    if max_similarity < similarity_thresh:
        return

    replace_lines = replace.splitlines()

    modified_whole = (
        whole_lines[:most_similar_chunk_start]
        + replace_lines
        + whole_lines[most_similar_chunk_end:]
    )
    modified_whole = "\n".join(modified_whole)

    if whole.endswith("\n"):
        modified_whole += "\n"

    return modified_whole
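
# Illustrative sketch, not part of this commit: the exact and whitespace-aware matches
# fail because of the "wrold" typo, so the SequenceMatcher fallback (similarity
# threshold 0.8) finds the closest run of lines and swaps it out. Invented strings.
#
# >>> whole = "print('hello')\nprint('world')\n"
# >>> part = "print('hello')\nprint('wrold')\n"
# >>> replace_most_similar_chunk(whole, part, "print('hello')\nprint('there')\n")
# "print('hello')\nprint('there')\n"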

def strip_quoted_wrapping(res, fname=None):
    """
    Given an input string which may have extra "wrapping" around it, remove the wrapping.
    For example:

    filename.ext
    ```
    We just want this content
    Not the filename and triple quotes
    ```
    """
    if not res:
        return res

    res = res.splitlines()

    if fname and res[0].strip().endswith(Path(fname).name):
        res = res[1:]

    if res[0].startswith("```") and res[-1].startswith("```"):
        res = res[1:-1]

    res = "\n".join(res)
    if res and res[-1] != "\n":
        res += "\n"

    return res
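
# Illustrative sketch, not part of this commit: the leading filename line and the
# surrounding triple-backtick fence are stripped, leaving only the content.
#
# >>> strip_quoted_wrapping("foo.txt\n```\nprint('hi')\n```\n", "foo.txt")
# "print('hi')\n"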

def do_replace(fname, before_text, after_text, dry_run=False):
    before_text = strip_quoted_wrapping(before_text, fname)
    after_text = strip_quoted_wrapping(after_text, fname)
    fname = Path(fname)

    # does it want to make a new file?
    if not fname.exists() and not before_text.strip():
        fname.touch()

    content = fname.read_text()

    if not before_text.strip():
        # append to existing file, or start a new file
        new_content = content + after_text
    else:
        new_content = replace_most_similar_chunk(content, before_text, after_text)
        if not new_content:
            return

    if not dry_run:
        fname.write_text(new_content)

    return True
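
# Illustrative sketch, not part of this commit: applying a single edit to a scratch
# file. The temp path and file contents are invented for the example.
#
# >>> import tempfile
# >>> demo = Path(tempfile.mkdtemp()) / "demo.py"
# >>> demo.write_text("x = 1\n")
# 6
# >>> do_replace(demo, "x = 1\n", "x = 2\n")
# True
# >>> demo.read_text()
# 'x = 2\n'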

ORIGINAL = "<<<<<<< ORIGINAL"
DIVIDER = "======="
UPDATED = ">>>>>>> UPDATED"

separators = "|".join([ORIGINAL, DIVIDER, UPDATED])

split_re = re.compile(r"^((?:" + separators + r")[ ]*\n)", re.MULTILINE | re.DOTALL)


def find_original_update_blocks(content):
    # make sure we end with a newline, otherwise the regex will miss <<UPD on the last line
    if not content.endswith("\n"):
        content = content + "\n"

    pieces = re.split(split_re, content)

    pieces.reverse()
    processed = []

    # Keep using the same filename in cases where GPT produces an edit block
    # without a filename.
    current_filename = None
    try:
        while pieces:
            cur = pieces.pop()

            if cur in (DIVIDER, UPDATED):
                processed.append(cur)
                raise ValueError(f"Unexpected {cur}")

            if cur.strip() != ORIGINAL:
                processed.append(cur)
                continue

            processed.append(cur)  # original_marker

            filename = processed[-2].splitlines()[-1].strip()
            try:
                if not len(filename) or "`" in filename:
                    filename = processed[-2].splitlines()[-2].strip()
                if not len(filename) or "`" in filename:
                    if current_filename:
                        filename = current_filename
                    else:
                        raise ValueError(
                            f"Bad/missing filename. It should go right above {ORIGINAL}"
                        )
            except IndexError:
                if current_filename:
                    filename = current_filename
                else:
                    raise ValueError(f"Bad/missing filename. It should go right above {ORIGINAL}")

            current_filename = filename

            original_text = pieces.pop()
            processed.append(original_text)

            divider_marker = pieces.pop()
            processed.append(divider_marker)
            if divider_marker.strip() != DIVIDER:
                raise ValueError(f"Expected {DIVIDER}")

            updated_text = pieces.pop()
            processed.append(updated_text)

            updated_marker = pieces.pop()
            processed.append(updated_marker)
            if updated_marker.strip() != UPDATED:
                raise ValueError(f"Expected {UPDATED}")

            yield filename, original_text, updated_text
    except ValueError as e:
        processed = "".join(processed)
        err = e.args[0]
        raise ValueError(f"{processed}\n^^^ {err}")
    except IndexError:
        processed = "".join(processed)
        raise ValueError(f"{processed}\n^^^ Incomplete ORIGINAL/UPDATED block.")
    except Exception:
        processed = "".join(processed)
        raise ValueError(f"{processed}\n^^^ Error parsing ORIGINAL/UPDATED block.")


if __name__ == "__main__":
    edit = """
Here's the change:
```text
foo.txt
<<<<<<< ORIGINAL
Two
=======
Tooooo
>>>>>>> UPDATED
```
Hope you like it!
"""
    print(list(find_original_update_blocks(edit)))
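
# The demo above should print roughly: [('foo.txt', 'Two\n', 'Tooooo\n')]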

View file

@@ -1,159 +1,8 @@
import math
import re
from difflib import SequenceMatcher
from pathlib import Path

from .dump import dump  # noqa: F401


def try_dotdotdots(whole, part, replace):
    """
    See if the edit block has ... lines.
    If not, return none.
    If yes, try and do a perfect edit with the ... chunks.
    If there's a mismatch or otherwise imperfect edit, raise ValueError.
    If perfect edit succeeds, return the updated whole.
    """

    dots_re = re.compile(r"(^\s*\.\.\.\n)", re.MULTILINE | re.DOTALL)

    part_pieces = re.split(dots_re, part)
    replace_pieces = re.split(dots_re, replace)

    if len(part_pieces) != len(replace_pieces):
        raise ValueError("Unpaired ... in edit block")

    if len(part_pieces) == 1:
        # no dots in this edit block, just return None
        return

    # Compare odd strings in part_pieces and replace_pieces
    all_dots_match = all(part_pieces[i] == replace_pieces[i] for i in range(1, len(part_pieces), 2))

    if not all_dots_match:
        raise ValueError("Unmatched ... in edit block")

    part_pieces = [part_pieces[i] for i in range(0, len(part_pieces), 2)]
    replace_pieces = [replace_pieces[i] for i in range(0, len(replace_pieces), 2)]

    pairs = zip(part_pieces, replace_pieces)
    for part, replace in pairs:
        if not part and not replace:
            continue

        if not part and replace:
            if not whole.endswith("\n"):
                whole += "\n"
            whole += replace
            continue

        if whole.count(part) != 1:
            raise ValueError(
                "No perfect matching chunk in edit block with ... or part appears more than once"
            )

        whole = whole.replace(part, replace, 1)

    return whole


def replace_part_with_missing_leading_whitespace(whole, part, replace):
    whole_lines = whole.splitlines()
    part_lines = part.splitlines()
    replace_lines = replace.splitlines()

    # If all lines in the part start with whitespace, then honor it.
    # But GPT often outdents the part and replace blocks completely,
    # thereby discarding the actual leading whitespace in the file.
    if all((len(pline) > 0 and pline[0].isspace()) for pline in part_lines):
        return

    for i in range(len(whole_lines) - len(part_lines) + 1):
        leading_whitespace = ""
        for j, c in enumerate(whole_lines[i]):
            if c == part_lines[0][0]:
                leading_whitespace = whole_lines[i][:j]
                break

        if not leading_whitespace or not all(c.isspace() for c in leading_whitespace):
            continue

        matched = all(
            whole_lines[i + k].startswith(leading_whitespace + part_lines[k])
            for k in range(len(part_lines))
        )

        if matched:
            replace_lines = [
                leading_whitespace + rline if rline else rline for rline in replace_lines
            ]
            whole_lines = whole_lines[:i] + replace_lines + whole_lines[i + len(part_lines) :]
            return "\n".join(whole_lines) + "\n"

    return None


def replace_most_similar_chunk(whole, part, replace):
    res = replace_part_with_missing_leading_whitespace(whole, part, replace)
    if res:
        return res

    if part in whole:
        return whole.replace(part, replace)

    try:
        res = try_dotdotdots(whole, part, replace)
    except ValueError:
        return

    if res:
        return res

    similarity_thresh = 0.8

    max_similarity = 0
    most_similar_chunk_start = -1
    most_similar_chunk_end = -1

    whole_lines = whole.splitlines()
    part_lines = part.splitlines()

    scale = 0.1
    min_len = math.floor(len(part_lines) * (1 - scale))
    max_len = math.ceil(len(part_lines) * (1 + scale))

    for length in range(min_len, max_len):
        for i in range(len(whole_lines) - length + 1):
            chunk = whole_lines[i : i + length]
            chunk = "\n".join(chunk)

            similarity = SequenceMatcher(None, chunk, part).ratio()

            if similarity > max_similarity and similarity:
                max_similarity = similarity
                most_similar_chunk_start = i
                most_similar_chunk_end = i + length

    if max_similarity < similarity_thresh:
        return

    replace_lines = replace.splitlines()

    modified_whole = (
        whole_lines[:most_similar_chunk_start]
        + replace_lines
        + whole_lines[most_similar_chunk_end:]
    )
    modified_whole = "\n".join(modified_whole)

    if whole.endswith("\n"):
        modified_whole += "\n"

    return modified_whole


def quoted_file(fname, display_fname, number=False):
    prompt = "\n"
    prompt += display_fname
@@ -169,60 +18,6 @@ def quoted_file(fname, display_fname, number=False):
    return prompt


def strip_quoted_wrapping(res, fname=None):
    """
    Given an input string which may have extra "wrapping" around it, remove the wrapping.
    For example:

    filename.ext
    ```
    We just want this content
    Not the filename and triple quotes
    ```
    """
    if not res:
        return res

    res = res.splitlines()

    if fname and res[0].strip().endswith(Path(fname).name):
        res = res[1:]

    if res[0].startswith("```") and res[-1].startswith("```"):
        res = res[1:-1]

    res = "\n".join(res)
    if res and res[-1] != "\n":
        res += "\n"

    return res


def do_replace(fname, before_text, after_text, dry_run=False):
    before_text = strip_quoted_wrapping(before_text, fname)
    after_text = strip_quoted_wrapping(after_text, fname)
    fname = Path(fname)

    # does it want to make a new file?
    if not fname.exists() and not before_text.strip():
        fname.touch()

    content = fname.read_text()

    if not before_text.strip():
        # append to existing file, or start a new file
        new_content = content + after_text
    else:
        new_content = replace_most_similar_chunk(content, before_text, after_text)
        if not new_content:
            return

    if not dry_run:
        fname.write_text(new_content)

    return True


def show_messages(messages, title=None):
    if title:
        print(title.upper(), "*" * 50)
@@ -232,105 +27,3 @@ def show_messages(messages, title=None):
        content = msg["content"].splitlines()
        for line in content:
            print(role, line)


ORIGINAL = "<<<<<<< ORIGINAL"
DIVIDER = "======="
UPDATED = ">>>>>>> UPDATED"

separators = "|".join([ORIGINAL, DIVIDER, UPDATED])

split_re = re.compile(r"^((?:" + separators + r")[ ]*\n)", re.MULTILINE | re.DOTALL)


def find_original_update_blocks(content):
    # make sure we end with a newline, otherwise the regex will miss <<UPD on the last line
    if not content.endswith("\n"):
        content = content + "\n"

    pieces = re.split(split_re, content)

    pieces.reverse()
    processed = []

    # Keep using the same filename in cases where GPT produces an edit block
    # without a filename.
    current_filename = None
    try:
        while pieces:
            cur = pieces.pop()

            if cur in (DIVIDER, UPDATED):
                processed.append(cur)
                raise ValueError(f"Unexpected {cur}")

            if cur.strip() != ORIGINAL:
                processed.append(cur)
                continue

            processed.append(cur)  # original_marker

            filename = processed[-2].splitlines()[-1].strip()
            try:
                if not len(filename) or "`" in filename:
                    filename = processed[-2].splitlines()[-2].strip()
                if not len(filename) or "`" in filename:
                    if current_filename:
                        filename = current_filename
                    else:
                        raise ValueError(
                            f"Bad/missing filename. It should go right above {ORIGINAL}"
                        )
            except IndexError:
                if current_filename:
                    filename = current_filename
                else:
                    raise ValueError(f"Bad/missing filename. It should go right above {ORIGINAL}")

            current_filename = filename

            original_text = pieces.pop()
            processed.append(original_text)

            divider_marker = pieces.pop()
            processed.append(divider_marker)
            if divider_marker.strip() != DIVIDER:
                raise ValueError(f"Expected {DIVIDER}")

            updated_text = pieces.pop()
            processed.append(updated_text)

            updated_marker = pieces.pop()
            processed.append(updated_marker)
            if updated_marker.strip() != UPDATED:
                raise ValueError(f"Expected {UPDATED}")

            yield filename, original_text, updated_text
    except ValueError as e:
        processed = "".join(processed)
        err = e.args[0]
        raise ValueError(f"{processed}\n^^^ {err}")
    except IndexError:
        processed = "".join(processed)
        raise ValueError(f"{processed}\n^^^ Incomplete ORIGINAL/UPDATED block.")
    except Exception:
        processed = "".join(processed)
        raise ValueError(f"{processed}\n^^^ Error parsing ORIGINAL/UPDATED block.")


if __name__ == "__main__":
    edit = """
Here's the change:
```text
foo.txt
<<<<<<< ORIGINAL
Two
=======
Tooooo
>>>>>>> UPDATED
```
Hope you like it!
"""
    print(list(find_original_update_blocks(edit)))