refactor shell_commands, so the run after autocommit

This commit is contained in:
Paul Gauthier 2024-08-22 13:27:01 -07:00
parent 20299b2927
commit 544b8dd800
2 changed files with 55 additions and 45 deletions

View file

@ -705,6 +705,7 @@ class Coder:
self.num_reflections = 0 self.num_reflections = 0
self.lint_outcome = None self.lint_outcome = None
self.test_outcome = None self.test_outcome = None
self.shell_commands = []
if self.repo: if self.repo:
self.commit_before_message.append(self.repo.get_head()) self.commit_before_message.append(self.repo.get_head())
@ -1106,6 +1107,8 @@ class Coder:
if self.reflected_message: if self.reflected_message:
return return
self.run_shell_commands()
if edited and self.auto_lint: if edited and self.auto_lint:
lint_errors = self.lint_edited(edited) lint_errors = self.lint_edited(edited)
self.auto_commit(edited) self.auto_commit(edited)
@ -1704,8 +1707,7 @@ class Coder:
try: try:
edits = self.get_edits() edits = self.get_edits()
edits = self.prepare_to_edit(edits) edits = self.prepare_to_edit(edits)
edited = set(edit[0] for edit in edits)
edited = set(edit[0] for edit in edits if edit[0])
self.apply_edits(edits) self.apply_edits(edits)
except ValueError as err: except ValueError as err:
self.num_malformed_responses += 1 self.num_malformed_responses += 1
@ -1827,3 +1829,33 @@ class Coder:
def apply_edits(self, edits): def apply_edits(self, edits):
return return
def run_shell_commands(self):
done = set()
for command in self.shell_commands:
if command in done:
continue
done.add(command)
self.handle_shell_commands(command)
def handle_shell_commands(self, commands_str):
commands = commands_str.strip().splitlines()
command_count = sum(
1 for cmd in commands if cmd.strip() and not cmd.strip().startswith("#")
)
prompt = "Run shell command?" if command_count == 1 else "Run shell commands?"
if not self.io.confirm_ask(prompt, subject="\n".join(commands), explicit_yes_required=True):
return
for command in commands:
command = command.strip()
if not command or command.startswith("#"):
continue
self.io.tool_output()
self.io.tool_output(f"Running {command}")
# Add the command to input history
self.io.add_to_input_history(f"/run {command.strip()}")
result = self.run_interactive_subprocess(command)
if result and result.stdout:
self.io.tool_output(result.stdout)

View file

@ -25,6 +25,9 @@ class EditBlockCoder(Coder):
# might raise ValueError for malformed ORIG/UPD blocks # might raise ValueError for malformed ORIG/UPD blocks
edits = list(find_original_update_blocks(content, self.fence)) edits = list(find_original_update_blocks(content, self.fence))
self.shell_commands += [edit[1] for edit in edits if edit[0] is None]
edits = [edit for edit in edits if edit[0] is not None]
return edits return edits
def run_interactive_subprocess(self, command): def run_interactive_subprocess(self, command):
@ -45,55 +48,29 @@ class EditBlockCoder(Coder):
self.io.tool_output(f"To retry and share output with the LLM: /run {command}") self.io.tool_output(f"To retry and share output with the LLM: /run {command}")
self.io.tool_output("You can find this command in your input history with up-arrow.") self.io.tool_output("You can find this command in your input history with up-arrow.")
def handle_shell_commands(self, commands_str):
commands = commands_str.strip().splitlines()
command_count = sum(
1 for cmd in commands if cmd.strip() and not cmd.strip().startswith("#")
)
prompt = "Run shell command?" if command_count == 1 else "Run shell commands?"
if not self.io.confirm_ask(prompt, subject="\n".join(commands), explicit_yes_required=True):
return
for command in commands:
command = command.strip()
if not command or command.startswith("#"):
continue
self.io.tool_output()
self.io.tool_output(f"Running {command}")
# Add the command to input history
self.io.add_to_input_history(f"/run {command.strip()}")
result = self.run_interactive_subprocess(command)
if result and result.stdout:
self.io.tool_output(result.stdout)
def apply_edits(self, edits): def apply_edits(self, edits):
failed = [] failed = []
passed = [] passed = []
for edit in edits: for edit in edits:
if edit[0] is None: path, original, updated = edit
self.handle_shell_commands(edit[1]) full_path = self.abs_root_path(path)
continue content = self.io.read_text(full_path)
else: new_content = do_replace(full_path, content, original, updated, self.fence)
path, original, updated = edit if not new_content:
full_path = self.abs_root_path(path) # try patching any of the other files in the chat
content = self.io.read_text(full_path) dump(self.abs_fnames)
new_content = do_replace(full_path, content, original, updated, self.fence) for full_path in self.abs_fnames:
if not new_content: content = self.io.read_text(full_path)
# try patching any of the other files in the chat new_content = do_replace(full_path, content, original, updated, self.fence)
dump(self.abs_fnames) if new_content:
for full_path in self.abs_fnames: break
content = self.io.read_text(full_path)
new_content = do_replace(full_path, content, original, updated, self.fence)
if new_content:
break
if new_content: if new_content:
self.io.write_text(full_path, new_content) self.io.write_text(full_path, new_content)
passed.append(edit) passed.append(edit)
else: else:
failed.append(edit) failed.append(edit)
if not failed: if not failed:
return return
@ -470,6 +447,7 @@ def find_original_update_blocks(content, fence=DEFAULT_FENCE):
i += 1 i += 1
if i < len(lines) and lines[i].strip().startswith("```"): if i < len(lines) and lines[i].strip().startswith("```"):
i += 1 # Skip the closing ``` i += 1 # Skip the closing ```
yield None, "".join(shell_content) yield None, "".join(shell_content)
continue continue