fix: Implement line-by-line processing for SEARCH/REPLACE and shell code blocks

This commit is contained in:
Paul Gauthier (aider) 2024-08-20 16:23:07 -07:00
parent f198c4a691
commit 72bc851ac0

View file

@ -404,94 +404,69 @@ def strip_filename(filename, fence):
def find_original_update_blocks(content, fence=DEFAULT_FENCE):
# make sure we end with a newline, otherwise the regex will miss <<UPD on the last line
if not content.endswith("\n"):
content = content + "\n"
pieces = re.split(split_re, content)
pieces.reverse()
processed = []
# Keep using the same filename in cases where GPT produces an edit block
# without a filename.
lines = content.splitlines(keepends=True)
i = 0
current_filename = None
try:
while pieces:
cur = pieces.pop()
dump(repr(cur))
# Check for various shell code blocks
shell_starts = [
"```bash",
"```sh",
"```shell", # Unix-like shells
"```cmd",
"```batch", # Windows Command Prompt
"```powershell",
"```ps1", # Windows PowerShell
"```zsh", # Z shell
"```fish", # Friendly Interactive Shell
"```ksh", # Korn Shell
"```csh",
"```tcsh", # C Shell and TENEX C Shell
]
if any(cur.strip().startswith(start) for start in shell_starts):
shell_content = []
while pieces and not pieces[-1].strip().startswith("```"):
shell_content.append(pieces.pop())
if pieces and pieces[-1].strip().startswith("```"):
pieces.pop() # Remove the closing ```
shell_type = cur.strip().split("```")[1]
yield f"{shell_type}_command", "".join(shell_content)
continue
while i < len(lines):
line = lines[i]
if cur in (DIVIDER, UPDATED):
processed.append(cur)
raise ValueError(f"Unexpected {cur}")
# Check for shell code blocks
shell_starts = [
"```bash", "```sh", "```shell", "```cmd", "```batch",
"```powershell", "```ps1", "```zsh", "```fish", "```ksh",
"```csh", "```tcsh"
]
if any(line.strip().startswith(start) for start in shell_starts):
shell_type = line.strip().split("```")[1]
shell_content = []
i += 1
while i < len(lines) and not lines[i].strip().startswith("```"):
shell_content.append(lines[i])
i += 1
if i < len(lines) and lines[i].strip().startswith("```"):
i += 1 # Skip the closing ```
yield f"{shell_type}_command", "".join(shell_content)
continue
if cur.strip() != HEAD:
processed.append(cur)
continue
# Check for SEARCH/REPLACE blocks
if line.strip() == HEAD:
try:
filename = find_filename(lines[max(0, i-3):i], fence)
if not filename:
if current_filename:
filename = current_filename
else:
raise ValueError(missing_filename_err.format(fence=fence))
processed.append(cur) # original_marker
current_filename = filename
filename = find_filename(processed[-2].splitlines(), fence)
if not filename:
if current_filename:
filename = current_filename
else:
raise ValueError(missing_filename_err.format(fence=fence))
original_text = []
i += 1
while i < len(lines) and not lines[i].strip() == DIVIDER:
original_text.append(lines[i])
i += 1
current_filename = filename
if i >= len(lines) or lines[i].strip() != DIVIDER:
raise ValueError(f"Expected `{DIVIDER}`")
original_text = pieces.pop()
processed.append(original_text)
updated_text = []
i += 1
while i < len(lines) and not lines[i].strip() == UPDATED:
updated_text.append(lines[i])
i += 1
divider_marker = pieces.pop()
processed.append(divider_marker)
if divider_marker.strip() != DIVIDER:
raise ValueError(f"Expected `{DIVIDER}` not {divider_marker.strip()}")
if i >= len(lines) or lines[i].strip() != UPDATED:
raise ValueError(f"Expected `{UPDATED}`")
updated_text = pieces.pop()
processed.append(updated_text)
yield filename, "".join(original_text), "".join(updated_text)
updated_marker = pieces.pop()
processed.append(updated_marker)
if updated_marker.strip() != UPDATED:
raise ValueError(f"Expected `{UPDATED}` not `{updated_marker.strip()}")
except ValueError as e:
processed = "".join(lines[:i+1])
err = e.args[0]
raise ValueError(f"{processed}\n^^^ {err}")
yield filename, original_text, updated_text
except ValueError as e:
processed = "".join(processed)
err = e.args[0]
raise ValueError(f"{processed}\n^^^ {err}")
except IndexError:
processed = "".join(processed)
raise ValueError(f"{processed}\n^^^ Incomplete SEARCH/REPLACE block.")
except Exception:
processed = "".join(processed)
raise ValueError(f"{processed}\n^^^ Error parsing SEARCH/REPLACE block.")
i += 1
def find_filename(lines, fence):