keep track of the files which need dirty commits

This commit is contained in:
Paul Gauthier 2023-08-18 09:50:18 -07:00
parent 752e47a886
commit 7af82ba185
2 changed files with 7 additions and 13 deletions

View file

@ -48,7 +48,6 @@ class Coder:
total_cost = 0.0
num_exhausted_context_windows = 0
last_keyboard_interrupt = None
need_commit_before_edits = False
@classmethod
def create(
@ -106,6 +105,7 @@ class Coder:
self.chat_completion_call_hashes = []
self.chat_completion_response_hashes = []
self.need_commit_before_edits = set()
self.verbose = verbose
self.abs_fnames = set()
@ -720,8 +720,7 @@ class Coder:
return
self.io.tool_output(f"Committing {path} before applying edits.")
self.repo.repo.git.add(path)
self.need_commit_before_edits = True
self.need_commit_before_edits.add(path)
return
def allowed_to_edit(self, path):
@ -772,7 +771,7 @@ class Coder:
res = []
seen = dict()
self.need_commit_before_edits = False
self.need_commit_before_edits = set()
for edit in edits:
path = edit[0]
@ -785,9 +784,8 @@ class Coder:
if allowed:
res.append(edit)
fnames = [edit[0] for edit in res]
self.dirty_commit(fnames)
self.need_commit_before_edits = False
self.dirty_commit()
self.need_commit_before_edits = set()
return res
@ -886,14 +884,14 @@ class Coder:
self.io.tool_output("No changes made to git tracked files.")
return self.gpt_prompts.files_content_gpt_no_edits
def dirty_commit(self, fnames):
def dirty_commit(self):
if not self.need_commit_before_edits:
return
if not self.dirty_commits:
return
if not self.repo:
return
self.repo.commit(fnames=fnames)
self.repo.commit(fnames=self.need_commit_before_edits)
# files changed, move cur messages back behind the files messages
self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)

View file

@ -227,7 +227,6 @@ class TestCoder(unittest.TestCase):
mock.return_value = set([str(fname)])
coder.repo.get_tracked_files = mock
dump(fname)
# Call the check_for_file_mentions method
coder.check_for_file_mentions(f"Please check `{fname}`")
@ -524,14 +523,12 @@ three
self.assertEqual(num_commits, 3)
diff = repo.git.diff(["HEAD~2", "HEAD~1"])
dump(diff)
self.assertIn("one", diff)
self.assertIn("two", diff)
self.assertNotIn("three", diff)
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
dump(saved_diffs)
diff = saved_diffs[0]
self.assertIn("one", diff)
self.assertIn("two", diff)
@ -553,7 +550,6 @@ three
self.assertNotIn("other", diff)
self.assertNotIn("OTHER", diff)
dump(saved_diffs)
self.assertEqual(len(saved_diffs), 2)