feat: Allow flexible matching of 5-9 characters in SEARCH/REPLACE block prefixes

This commit is contained in:
Paul Gauthier (aider) 2024-09-20 13:44:02 -07:00
parent 230ec50209
commit 7fa1620f58

View file

@ -365,9 +365,13 @@ def do_replace(fname, content, before_text, after_text, fence=None):
return new_content
HEAD = "<<<<<<< SEARCH"
DIVIDER = "======="
UPDATED = ">>>>>>> REPLACE"
HEAD = r"<{5,9} SEARCH"
DIVIDER = r"={5,9}"
UPDATED = r">{5,9} REPLACE"
HEAD_ERR = "<<<<<<< SEARCH"
DIVIDER_ERR = "======="
UPDATED_ERR = ">>>>>>> REPLACE"
separators = "|".join([HEAD, DIVIDER, UPDATED])
@ -407,6 +411,10 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE, valid_fnames=None)
i = 0
current_filename = None
head_pattern = re.compile(HEAD)
divider_pattern = re.compile(DIVIDER)
updated_pattern = re.compile(UPDATED)
while i < len(lines):
line = lines[i]
@ -425,7 +433,7 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE, valid_fnames=None)
"```csh",
"```tcsh",
]
next_is_editblock = i + 1 < len(lines) and lines[i + 1].rstrip() == HEAD
next_is_editblock = i + 1 < len(lines) and head_pattern.match(lines[i + 1].strip())
if any(line.strip().startswith(start) for start in shell_starts) and not next_is_editblock:
shell_content = []
@ -440,10 +448,10 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE, valid_fnames=None)
continue
# Check for SEARCH/REPLACE blocks
if line.strip() == HEAD:
if head_pattern.match(line.strip()):
try:
# if next line after HEAD exists and is DIVIDER, it's a new file
if i + 1 < len(lines) and lines[i + 1].strip() == DIVIDER:
if i + 1 < len(lines) and divider_pattern.match(lines[i + 1].strip()):
filename = find_filename(lines[max(0, i - 3) : i], fence, None)
else:
filename = find_filename(
@ -460,21 +468,21 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE, valid_fnames=None)
original_text = []
i += 1
while i < len(lines) and not lines[i].strip() == DIVIDER:
while i < len(lines) and not divider_pattern.match(lines[i].strip()):
original_text.append(lines[i])
i += 1
if i >= len(lines) or lines[i].strip() != DIVIDER:
raise ValueError(f"Expected `{DIVIDER}`")
if i >= len(lines) or not divider_pattern.match(lines[i].strip()):
raise ValueError(f"Expected `{DIVIDER_ERR}`")
updated_text = []
i += 1
while i < len(lines) and not lines[i].strip() in (UPDATED, DIVIDER):
while i < len(lines) and not (updated_pattern.match(lines[i].strip()) or divider_pattern.match(lines[i].strip())):
updated_text.append(lines[i])
i += 1
if i >= len(lines) or lines[i].strip() not in (UPDATED, DIVIDER):
raise ValueError(f"Expected `{UPDATED}` or `{DIVIDER}`")
if i >= len(lines) or not (updated_pattern.match(lines[i].strip()) or divider_pattern.match(lines[i].strip())):
raise ValueError(f"Expected `{UPDATED_ERR}` or `{DIVIDER_ERR}`")
yield filename, "".join(original_text), "".join(updated_text)