fix: Implement line-by-line processing for SEARCH/REPLACE and shell code blocks

This commit is contained in:
Paul Gauthier (aider) 2024-08-20 16:23:07 -07:00
parent f198c4a691
commit 72bc851ac0

View file

@ -404,59 +404,35 @@ def strip_filename(filename, fence):
def find_original_update_blocks(content, fence=DEFAULT_FENCE): def find_original_update_blocks(content, fence=DEFAULT_FENCE):
# make sure we end with a newline, otherwise the regex will miss <<UPD on the last line lines = content.splitlines(keepends=True)
if not content.endswith("\n"): i = 0
content = content + "\n"
pieces = re.split(split_re, content)
pieces.reverse()
processed = []
# Keep using the same filename in cases where GPT produces an edit block
# without a filename.
current_filename = None current_filename = None
try:
while pieces:
cur = pieces.pop()
dump(repr(cur))
# Check for various shell code blocks while i < len(lines):
line = lines[i]
# Check for shell code blocks
shell_starts = [ shell_starts = [
"```bash", "```bash", "```sh", "```shell", "```cmd", "```batch",
"```sh", "```powershell", "```ps1", "```zsh", "```fish", "```ksh",
"```shell", # Unix-like shells "```csh", "```tcsh"
"```cmd",
"```batch", # Windows Command Prompt
"```powershell",
"```ps1", # Windows PowerShell
"```zsh", # Z shell
"```fish", # Friendly Interactive Shell
"```ksh", # Korn Shell
"```csh",
"```tcsh", # C Shell and TENEX C Shell
] ]
if any(cur.strip().startswith(start) for start in shell_starts): if any(line.strip().startswith(start) for start in shell_starts):
shell_type = line.strip().split("```")[1]
shell_content = [] shell_content = []
while pieces and not pieces[-1].strip().startswith("```"): i += 1
shell_content.append(pieces.pop()) while i < len(lines) and not lines[i].strip().startswith("```"):
if pieces and pieces[-1].strip().startswith("```"): shell_content.append(lines[i])
pieces.pop() # Remove the closing ``` i += 1
shell_type = cur.strip().split("```")[1] if i < len(lines) and lines[i].strip().startswith("```"):
i += 1 # Skip the closing ```
yield f"{shell_type}_command", "".join(shell_content) yield f"{shell_type}_command", "".join(shell_content)
continue continue
if cur in (DIVIDER, UPDATED): # Check for SEARCH/REPLACE blocks
processed.append(cur) if line.strip() == HEAD:
raise ValueError(f"Unexpected {cur}") try:
filename = find_filename(lines[max(0, i-3):i], fence)
if cur.strip() != HEAD:
processed.append(cur)
continue
processed.append(cur) # original_marker
filename = find_filename(processed[-2].splitlines(), fence)
if not filename: if not filename:
if current_filename: if current_filename:
filename = current_filename filename = current_filename
@ -465,33 +441,32 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE):
current_filename = filename current_filename = filename
original_text = pieces.pop() original_text = []
processed.append(original_text) i += 1
while i < len(lines) and not lines[i].strip() == DIVIDER:
original_text.append(lines[i])
i += 1
divider_marker = pieces.pop() if i >= len(lines) or lines[i].strip() != DIVIDER:
processed.append(divider_marker) raise ValueError(f"Expected `{DIVIDER}`")
if divider_marker.strip() != DIVIDER:
raise ValueError(f"Expected `{DIVIDER}` not {divider_marker.strip()}")
updated_text = pieces.pop() updated_text = []
processed.append(updated_text) i += 1
while i < len(lines) and not lines[i].strip() == UPDATED:
updated_text.append(lines[i])
i += 1
updated_marker = pieces.pop() if i >= len(lines) or lines[i].strip() != UPDATED:
processed.append(updated_marker) raise ValueError(f"Expected `{UPDATED}`")
if updated_marker.strip() != UPDATED:
raise ValueError(f"Expected `{UPDATED}` not `{updated_marker.strip()}") yield filename, "".join(original_text), "".join(updated_text)
yield filename, original_text, updated_text
except ValueError as e: except ValueError as e:
processed = "".join(processed) processed = "".join(lines[:i+1])
err = e.args[0] err = e.args[0]
raise ValueError(f"{processed}\n^^^ {err}") raise ValueError(f"{processed}\n^^^ {err}")
except IndexError:
processed = "".join(processed) i += 1
raise ValueError(f"{processed}\n^^^ Incomplete SEARCH/REPLACE block.")
except Exception:
processed = "".join(processed)
raise ValueError(f"{processed}\n^^^ Error parsing SEARCH/REPLACE block.")
def find_filename(lines, fence): def find_filename(lines, fence):