do the commit before applying edits

This commit is contained in:
Paul Gauthier 2023-08-17 10:40:01 -07:00
parent 2455676a44
commit d981de5f5f
3 changed files with 61 additions and 42 deletions

View file

@ -48,6 +48,7 @@ class Coder:
total_cost = 0.0
num_exhausted_context_windows = 0
last_keyboard_interrupt = None
need_commit_before_edits = False
@classmethod
def create(
@ -728,47 +729,80 @@ class Coder:
def get_addable_relative_files(self):
return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())
def check_for_dirty_commit(self, path):
if not self.repo:
return True
if not self.dirty_commits:
return True
if not self.repo.is_dirty(path):
return True
self.io.tool_output(f"Committing {path} before applying edits.")
self.repo.repo.git.add(path)
self.need_commit_before_edits = True
return True
def allowed_to_edit(self, path):
full_path = self.abs_root_path(path)
tracked_files = set(self.repo.get_tracked_files())
is_in_repo = path in tracked_files
if full_path in self.abs_fnames:
if self.check_for_dirty_commit(path):
return full_path
if not Path(full_path).exists():
question = f"Allow creation of new file {path}?" # noqa: E501
else:
question = f"Allow edits to {path} which was not previously provided?" # noqa: E501
if not self.io.confirm_ask(question):
self.io.tool_error(f"Skipping edit to {path}")
return
if not Path(full_path).exists() and not self.dry_run:
if not Path(full_path).exists():
if not self.io.confirm_ask(f"Allow creation of new file {path}?"):
self.io.tool_error(f"Skipping edits to {path}")
return
if not self.dry_run:
Path(full_path).parent.mkdir(parents=True, exist_ok=True)
Path(full_path).touch()
self.abs_fnames.add(full_path)
# Check if the file is already in the repo
if self.repo:
tracked_files = set(self.repo.get_tracked_files())
relative_fname = self.get_rel_fname(full_path)
if relative_fname not in tracked_files and self.io.confirm_ask(f"Add {path} to git?"):
if not self.dry_run:
if self.repo and not is_in_repo:
self.repo.repo.git.add(full_path)
self.abs_fnames.add(full_path)
return full_path
if not self.io.confirm_ask(
f"Allow edits to {path} which was not previously added to chat?"
):
self.io.tool_error(f"Skipping edits to {path}")
return
if self.repo and not is_in_repo:
self.repo.repo.git.add(full_path)
self.abs_fnames.add(full_path)
if self.check_for_dirty_commit(path):
return full_path
apply_update_errors = 0
def prepare_to_edit(self, edits):
res = []
seen = dict()
self.need_commit_before_edits = False
for edit in edits:
path = edit[0]
rest = edit[1:]
if path in seen:
full_path = seen[path]
else:
full_path = self.allowed_to_edit(path)
seen[path] = full_path
edit = [path, full_path] + list(rest)
res.append(edit)
self.dirty_commit()
self.need_commit_before_edits = False
return res
def apply_updates(self):
@ -867,6 +901,7 @@ class Coder:
return self.gpt_prompts.files_content_gpt_no_edits
def should_dirty_commit(self, inp):
return
cmds = self.commands.matching_commands(inp)
if cmds:
matching_commands, _, _ = cmds
@ -880,29 +915,13 @@ class Coder:
return True
def dirty_commit(self):
if not self.need_commit_before_edits:
return
if not self.dirty_commits:
return
if not self.repo:
return
if not self.repo.is_dirty():
return
self.io.tool_output("Git repo has uncommitted changes.")
self.repo.show_diffs(self.pretty)
self.last_asked_for_commit_time = self.get_last_modified()
res = self.io.prompt_ask(
"Commit before the chat proceeds [y/n/commit message]?",
default="y",
).strip()
if res.lower() in ["n", "no"]:
self.io.tool_error("Skipped commmit.")
return
if res.lower() in ["y", "yes"]:
message = None
else:
message = res.strip()
self.repo.commit(message=message)
self.repo.commit()
# files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)

View file

@ -522,7 +522,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
io.tool_output("Use /help to see in-chat commands, run with --help to see cmd line args")
coder.dirty_commit()
# coder.dirty_commit()
if args.message:
io.tool_output()

View file

@ -190,5 +190,5 @@ class GitRepo:
return res
def is_dirty(self):
return self.repo.is_dirty()
def is_dirty(self, path=None):
return self.repo.is_dirty(path=path)